diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 3423c416a7..0000000000 --- a/.gitignore +++ /dev/null @@ -1,11 +0,0 @@ -data/tokenizers -__pycache__/ -.DS_Store -modded-nanogpt/ -modded-nanogpt -data/datasets -data/manifest.json -data/docs_selected.jsonl -.mypy_cache/ -.venv -logs/ \ No newline at end of file diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 4243cb4a9a..0000000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2026 OpenAI - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/README.md b/README.md index 39012623ea..6754074834 100644 --- a/README.md +++ b/README.md @@ -1,235 +1,183 @@ -1920x640-discord +# Non-Record: CAT, Sparsity (Structured and Hessian-Guided), MoE, KAN Negative Results -
-
+**Tokenizer:** SP8192 +**Submission type:** Non-record (negative results & technique exploration) -**OpenAI Model Craft Challenge: Parameter Golf** is a challenge to train the best language model that fits in a 16MB artifact and trains in under 10 minutes on 8xH100s, evaluated by compression on the FineWeb validation set (tokenizer-agnostic, bits per byte). +--- -This challenge is heavily inspired by the [NanoGPT Speedrunning](https://github.com/KellerJordan/modded-nanogpt) challenge, where participants compete to train a model that reaches 3.28 FineWeb validation loss as quickly as possible. We're excited to see how optimizing for a parameter-constrained setting pushes people toward unique architectures (test-time compute, aggressive parameter tying, depth recurrence, low-rank training, ...), compression schemes (low precision, QAT, bitnets, novel tokenizers, ...), and other creative submissions (test-time training, long context, megakernels ...). +## Summary -If you're familiar with [neural scaling laws](https://arxiv.org/abs/2001.08361), you can consider this challenge a form of L(N) optimization, where the objective is to optimize the lowest loss given a fixed number of parameters (N) unconstrained by data, compute, steps, or architecture. Challenges like the [NanoGPT Speedrun](https://github.com/KellerJordan/modded-nanogpt), which optimizes for a form of L(T) (~lowest time given constrained loss) or the [NanoGPT Slowrun](https://github.com/qlabs-eng/slowrun), which optimizes for L(D) (lowest loss given constrained dataset size), can be thought of as equivalent challenges in this family. +This submission documents a systematic exploration of some novel techniques, built on top of the PR #1394. All techniques were implemented as toggleable features in a single training script and evaluated across 14 runs on 1xH100 SXM (RunPod) using a scaled-down "medium" configuration (2000 steps, 2048 seq_len, 5 train shards). -Ideally, we'd allow for submissions to use arbitrary computational resources. But in order to make the challenge not inaccessibly expensive, we're limiting *leaderboard submissions* to 10 minutes on 8xH100s. However, we'd still love to see submissions that don't meet the compute limitation requirements in our 'Non-record Submissions' section: We're excited to see people push the infinite frontier of parameter limited performance as well. +**Key finding:** None of the novel techniques (Sparsity, Hessian-Guided Sparsity, MoE, KAN) improved BPB over a well-tuned baseline that combines established techniques from the top leaderboard submissions (parallel residuals from PR #1412, TTT from PR #1413, QK gain tuning from PR #1493, CAT idea from PR #1385). The most effective strategy was simply combining some known techniques from these PRs. -We also know compute is expensive, so **OpenAI is sponsoring $1,000,000 in compute credits** to help people get started training their models. To request a compute grant, use this form: [Request a Compute Grant](https://openai.com/index/parameter-golf/#credit-form). -When requesting compute, please make sure you choose the appropriate level, write sufficient justification, and **submit with an email tied to a OpenAI / ChatGPT account**. +--- -## Participant Form +## Hardware & Training Setup -If you enjoy solving very difficult technical problems, please introduce yourself via the [Challenge Participant Form](https://jobs.ashbyhq.com/openai/form/open-ai-challenge-parameter-golf). It helps us attribute challenge submissions and reach out about opportunities with OpenAI. _Completing the form is not required to participate._ +| Parameter | Value | +|---|---| +| Hardware | 1xH100 SXM 80GB (RunPod) | +| Training mode | `TEST_MODE=medium` (scaled-down) | +| Training steps | 2000 | +| Sequence length | 2048 | +| Train shards | 5 (~500M tokens vs. 8B+ full) | +| GPUs | 1 (vs. 8 for competition) | -Many researchers at OpenAI first distinguished themselves through elite mathematics and programming competitions. The Model Craft Challenge is designed in that spirit: testing the ability to tackle unfamiliar problems with creativity and rigor, qualities we believe are essential for frontier AI research. +**Note:** All BPB values are from medium runs and are not directly comparable to full 8xH100 submissions. The relative comparisons between techniques are valid since all used identical medium configuration. -In June, we plan to hire a small cohort of early-career researchers, targeting current undergraduate students and recent graduates, including Olympiad medalists and elite competitors. For exceptional participants, the challenge may also serve as a way to stand out to OpenAI researchers and recruiters. +--- -The challenge runs from March 18th to April 30th. +## Results -Happy training! +### Round 1: Initial Novel Techniques on top of PR #1394 -## Leaderboard +Architecture: `loop_start=4, loop_end=5, qk_gain=4.0, warmdown=0.667, muon_wd=0.085` -| Run | Score | Author | Summary | Date | Info | -|-----|------:|--------|---------|------|------| -| 11L AR Self-Gen GPTQ + XSA | 1.1147 | abaybektursun | On PR #1019: Self-Generated GPTQ Calibration Data + all-layer XSA on the PR #549 stack | 2026-03-25 | [info](records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md) | -| LeakyReLU² + Legal Score-First TTT + Parallel Muon | 1.1194 | abaybektursun | On PR #549: LeakyReLU(0.5)^2 + TTT + Parallel Muon on the PR #414 stack | 2026-03-23 | [info](records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/README.md) | -| 11L EMA + GPTQ-lite + warmdown3500 | 1.1228 | signalrush | On PR #374: GPTQ-lite clip search + EMA, plus warmdown3500 and QAT@0.15 | 2026-03-22 | [info](records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/README.md) | -| 11L Partial RoPE + LN Scale + EMA + XSA4 | 1.1248 | jfprincz | On PR #287: Partial RoPE (16/64) + layerwise LN scale | 2026-03-21 | [info](records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/README.md) | -| 11L XSA4 + EMA + Int6 MLP3x | 1.1271 | jfprincz | On PR #198: XSA on the last 4 layers + EMA replacing SWA | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/README.md) | -| 11L Efficient Partial XSA | 1.1307 | unnir | On PR #198: Efficient Partial XSA on the deepest 3 layers | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/README.md) | -| 10L Int5-MLP + BigramHash(10240) | 1.1428 | thwu1 | 10 layers, mixed int5/int6 quantization, BigramHash(10240), SWA(0.4), WD=0.04 | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/README.md) | -| Int6 MLP3x + SmearGate + BigramHash | 1.1458 | Raahil Shah | 3x MLP + SmearGate + BigramHash + OrthoInit + Muon WD + SWA | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/README.md) | -| 11L MLP3x + Int6 QAT | 1.1502 | aruniyer | 11 layers, 3x MLP, int6 QAT, zstd-22, WD=0.04, sliding eval | 2026-03-20 | [info](records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/README.md) | -| SmearGate + OrthoInit + Muon WD | 1.1556 | aquariouseworkman | SmearGate + BigramHash + 3x MLP + int6 STE QAT + sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md) | -| Ternary Quantization | 1.1570 | Ciprian-Florin Ifrim | 73.7M params quantized to 1 0 -1 + misc arch changes | 2026-03-24 | [info](records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/README.md) | -| 10L Int6 QAT + Zstd MLP2.6x | 1.1586 | yahya010 | 10 layers, int6 QAT + zstd-22, MLP 1344, Muon 0.99, sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/README.md) | -| Mixed Quant + Sliding Window Eval | 1.1630 | aquariouseworkman | Int6 block weights + int8 embeddings + 3x MLP + sliding eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md) | -| Muon WD + 10 layer | 1.1748 | notapplica | Includes prev. wins + Spectral embed init + resid mix | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/README.md) | -| Sliding Window Eval | 1.1925 | Matthew Li | Sliding window evaluation at stride=64, increasing context for eval | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_SlidingWindowEval/README.md) | -| Lora TTT | 1.1928 | samacqua | Test-time training with LORAs | 2026-03-19 | [info](records/track_10min_16mb/2026-03-17_LoRA_TTT/README.md) | -| 4k seq length| 1.2014 | Spokane Way | 4k seq length + better hypers | 2026-03-19 | [info](records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/README.md) | -| 2048 seq length | 1.206 | Spokane Way | 2048 seq length (train + val) | 2026-03-18 | [info](records/track_10min_16mb/2026-03-18_LongContextSeq2048/README.md) | -| int6 mixed precision | 1.2147 | Nan Liu | 10 layers, mixed int8/int6 | 2026-03-18 | [info](records/track_10min_16mb/2026-03-19_10L_MixedPrecision/README.md) | -| fp16 Embed | 1.2197 | Renier Velazco | FP16 Tied Embedding + LR/Warmdown Tuning | 2026-03-18 | [info](records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/README.md) | -| Naive Baseline | 1.2244 | Baseline | 9layer 512dim 1024vocab TiedEmbeddings 4 KV heads | 2026-03-18 | [info](records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md) | +| # | Experiment | Layers | Params | val_bpb | Pre-quant BPB | Size (bytes) | Under 16MB? | +|---|---|---|---|---|---|---|---| +| 1 | Baseline (PR #1394) | 11 | 35.9M | 1.6567 | 1.6508 | 16,053,780 | Borderline | +| 2 | 2:4 Sparsity | 13 | 41.7M | 1.5040 | 1.4616 | 16,624,092 | Over | +| 3 | MoE (4 experts, top-2) | 11 | 105.2M | 1.4367 | 1.4291 | 45,409,570 | **3x over** | +| 4 | KAN (grid=5, order=3) | 11 | 128.2M | 1.5322 | 1.5215 | 55,011,901 | **3.4x over** | +| 5 | CAT (every 50 steps) | 11 | 35.9M | 1.4759 | 1.4680 | 16,069,715 | Borderline | +| 6 | CAT + Sparsity | 13 | 41.7M | 1.4964 | 1.4551 | 16,620,887 | Over | -#### Unlimited Compute Leaderboard & Non-record Submissions +### Round 2: Combined with Top Submission Techniques -| Run | Score | Author | Summary | Date | Info | -|-----|------:|--------|---------|------|------| -| 1 Bit Quantization | 1.1239 | Ciprian-Florin Ifrim | 106M params quantized to 1 bit + misc arch changes + 2hr training | 2026-03-24 | [info](records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/README.md) | -| 4-Hour Baseline | 1.2074 | Will DePue | Testing unlimited compute, 4 hours on 8xH100 | 2026-03-18 | [info](records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/README.md) | +Architecture: parallel residuals from layer 7, `loop_start=3, loop_end=5, qk_gain=5.25, warmdown=0.72, matrix_lr=0.022, muon_wd=0.095, ema=0.9965` -#### Requests for PRs +| # | Experiment | Layers | val_bpb | Pre-quant BPB | Size (bytes) | Under 16MB? | +|---|---|---|---|---|---|---| +| 7 | Baseline (new defaults) | 11 | 1.4096 | 1.3998 | 16,077,847 | Borderline | +| 8 | CAT + Sparsity | 12 | 1.4402 | 1.4031 | 15,605,512 | Yes | +| 9 | CAT + Sparsity | 13 | 1.4277 | 1.4006 | 16,687,493 | Over | +| 10 | CAT only (no sparsity) | 11 | 1.4101 | 1.3998 | 16,079,171 | Borderline | +| 11 | CAT + Sparsity + Wide Loop (3-6) | 12 | 1.4290 | 1.4013 | 15,596,401 | Yes | -Breakthrough ideas are rarely immediately state-of-the-art, instead, they're developed slowly, first demonstrating signs-of-life, iterated on, then only ultimately optimized on the systems side. Don't get discouraged if a new algorithm doesn't instantly beat the best leaderboard run or even the naive baseline. If you have an idea you believe in, consider ignoring step times early on: once you prove you can beat the baseline in the same # of steps you can then start focusing on how to also make it fast. +### Round 3: TTT + Hessian-Guided Sparsity (Sliding Window Enabled) -We'd love to see weird & creative ideas in the challenge, since you never know what may work in the end. Most likely, these will be a good fit in our unlimited compute leaderboard as non-record submissions. We have some requests for what we'd love to see people implement: +Same architecture as Round 2, with sliding window eval and score-first TTT active. -- [x] 1-bit quantization - [implementation](records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/README.md) -- [x] Ternary quantization - [implementation](records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/README.md) -- [ ] JEPA -- [ ] Text diffusion -- [ ] H-net tokenization -- [ ] Universal transformer - [We have lots of depth recurrence submissions, but I'd love to see one 4 hour -- [ ] Megakernels -- [ ] State-space models, E2E TTT, super long context for evaluation or training -- [ ] Learning adapters on random linear maps +| # | Experiment | Layers | val_bpb | Sliding BPB | TTT BPB | Size (bytes) | Under 16MB? | +|---|---|---|---|---|---|---|---| +| **12** | **Baseline + TTT** | **11** | **1.3971** | **1.3817** | **1.3696** | **16,076,488** | **Borderline** | +| 13 | CAT + H-Sparsity + TTT | 11 | 1.4564 | 1.4416 | 1.4045 | 14,700,464 | Yes | +| 14 | CAT + H-Sparsity + TTT | 12 | 1.4633 | 1.4513 | 1.4120 | 15,757,920 | Yes | -## Getting Started +--- -### Training Your First Model (Mac with Apple Silicon) +## Novel Techniques Explored -If you have an Apple laptop or desktop with Apple Silicon, we've set up a simple MLX training script to help you start iterating locally. +### 1. Compressor-Aware Training (CAT) -- idea from PR #1385 -If you don't have a Mac with Apple Silicon, you can run an adapted version of this script without MLX support. Just ask [Codex](https://openai.com/codex/) to refactor it; the change is straightforward. It may still be fairly slow, so we recommend jumping straight to cloud GPUs with Runpod. +**Motivation:** In the parameter golf pipeline, model weights are quantized (GPTQ, int6) and then entropy-coded (brotli). These compression steps are applied post-training, so the model has no incentive during training to produce weights that are easy to quantize or compress. CAT introduces a differentiable proxy for quantization loss directly into the training objective, encouraging the model to learn weight distributions that are "compression-friendly" — weights that naturally cluster near quantization grid points, resulting in lower entropy and better brotli compression ratios. -First, clone the repository, create a fresh Python environment, and install the packages needed for the MLX path plus dataset download: +**How it works:** Every `CAT_EVERY` training steps, a soft-rounding regularization loss is computed over all large weight matrices. For each weight, the distance to the nearest quantization grid point is measured, and a sigmoid function creates a differentiable penalty that is higher when weights fall between grid points. This loss is scaled by `CAT_WEIGHT=0.001` and added to the language modeling loss. The intuition is that weights near grid boundaries contribute the most quantization error; nudging them during training should reduce post-training quantization degradation. -```bash -git clone https://github.com/openai/parameter-golf.git -cd parameter-golf -python3 -m venv .venv -source .venv/bin/activate -python -m pip install --upgrade pip -pip install mlx numpy sentencepiece huggingface-hub datasets tqdm -``` +**Result:** CAT reduced compressed model size by ~70KB (~0.4% savings) but consistently degraded BPB by 0.01-0.06. The regularization disrupts the optimization landscape enough to hurt training quality, and the compression savings are too small to compensate — even when the saved bytes were used for additional model layers. -Download our cached version of FineWeb with the 1024-token vocabulary: +**Verdict: Negative.** GPTQ with SDClip already handles quantization effectively. Adding a training-time compression proxy introduces a conflicting objective that hurts final model quality more than it helps compression. -```bash -python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 10 -``` +### 2. 2:4 Structured Sparsity (Magnitude-Based) -This populates `./data/datasets/fineweb10B_sp1024/` and `./data/tokenizers/`. -By default this downloads the full validation split plus 80 training shards (8B tokens). For a smaller local smoke subset, pass `--train-shards 1`, for example `python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 1`. +**Motivation:** The 16MB artifact constraint is the binding bottleneck. If 50% of weights can be zeroed out in a structured pattern (keeping the 2 largest magnitudes per group of 4), the resulting sparse matrices should compress dramatically under brotli, freeing space for more model capacity (additional layers). The 2:4 pattern is also hardware-friendly on NVIDIA Ampere/Hopper GPUs, which have native 2:4 sparse tensor cores for inference acceleration. -Then run a small MLX training job: +**How it works:** After training completes, all MLP weight matrices are reshaped into groups of 4 columns. Within each group, the 2 weights with smallest absolute magnitude are zeroed. The sparsified state dict is then passed to GPTQ and brotli compression. With ~50% of MLP weights zeroed, the entropy of the weight distribution drops and brotli achieves much better compression ratios, saving ~1.5MB. -```bash -RUN_ID=mlx_smoke \ -ITERATIONS=200 \ -TRAIN_BATCH_TOKENS=8192 \ -VAL_LOSS_EVERY=0 \ -VAL_BATCH_SIZE=8192 \ -python3 train_gpt_mlx.py -``` +**Result:** Sparsity saved ~1.5MB in artifact size, allowing 12-13 layer models to fit under 16MB. However, the BPB degradation from zeroing weights (0.03-0.06 worse) consistently exceeded the improvement from additional layers at 2000 training steps. The pre-quantization BPB was comparable, but post-GPTQ the sparse models suffered more. -Validation always runs on the full `fineweb_val_*` split, which is the fixed first-50k-document set. The smoke command above skips periodic validation and just prints the final `val_loss` and `val_bpb` once at the end. +**Verdict: Negative.** 50% sparsity is could be aggressive at this model scale. The information destroyed by pruning half the MLP weights outweighs the capacity gained from 1-2 extra layers. -### Scaling Up to a Remote Machine +### 3. Hessian-Guided 2:4 Sparsity -Once you're happy with your local tests, or you want more compute, switch to a remote CUDA machine. +**Motivation:** Naive magnitude-based pruning assumes that the smallest weights are the least important. This might not always be true — a small weight connected to a high-curvature input dimension may contribute disproportionately to the loss function. GPTQ already collects full Hessian matrices (H = X^T X) for quantization. These same Hessians encode which input dimensions are most important. By combining weight magnitude with Hessian diagonal importance, we can make better pruning decisions. -You can rent GPUs from anywhere, but OpenAI is partnering with Runpod to make setup as easy as possible. +**How it works:** The importance score for each weight is computed as `|w_ij| * sqrt(H_jj)`, where `H_jj` is the diagonal of the Hessian matrix for that layer's input. This replaces the standard `|w_ij|` magnitude criterion. The Hessians are collected as part of the existing GPTQ pipeline, so this adds zero computational overhead. Within each group of 4, the 2 weights with the highest combined importance are kept. -#### Launching a 1xH100 Pod +**Result:** Hessian-guided sparsity produced similar results than naive magnitude pruning, but both were substantially worse than no sparsity (1.3696 baseline TTT BPB). The fundamental constraint is that zeroing 50% of weights, regardless of how intelligently they are selected, could be removing too much model capacity at this scale. -1. First, [create a Runpod account](https://console.runpod.io/deploy). You should also set up an SSH key in the Settings tab on the left so you can connect to your remote machine. If you're new to this, ask Codex to help you set it up. +**Verdict: Negative.** While the Hessian-guided importance criterion is theoretically sound and adds zero overhead, the 2:4 structured constraint forces exactly 50% sparsity, which might be too aggressive. Unstructured or lower-ratio pruning (e.g., 20-30%) might help, but would yield smaller compression savings. -2. Once you've set up your account, create a new GPU Cloud Pod. You can choose whichever GPU SKU you'd like. Final leaderboard submissions must run in under 10 minutes on 8xH100s (specifically the SXM variant), but we strongly recommend testing and running experiments on cheaper SKUs first, since an 8xH100 box can cost around $20/hour. +### 4. Mixture of Experts (MoE) -3. Let's start with a 1xH100 pod. Deploy using the official Parameter Golf template: [Launch Template](https://console.runpod.io/deploy?template=y5cejece4j&ref=nl2r56th). Enable SSH terminal access, leaving the other settings at their defaults. Deploy your pod and SSH into it once it's up. You should land in `/workspace/`. +**Motivation:** MoE architectures increase model capacity without proportionally increasing computation per token. By replacing each MLP with N independent expert MLPs and a learned router that activates only the top-K experts per token, the model can maintain a much larger total parameter count while keeping training and inference costs manageable. In the parameter golf context, MoE could theoretically achieve better BPB by having more specialized experts, even if the total model is larger — provided the weights compress well enough to fit in 16MB. -On your remote machine, clone the repo onto local disk. All Python dependencies are already pre-installed in the image. +**How it works:** Each transformer block's MLP is replaced with 4 independent expert MLPs and a learned router. For each token, the router selects the top-2 experts by softmax probability. Only those 2 experts process the token, and their outputs are combined weighted by the router probabilities. A load-balancing auxiliary loss (`alpha=0.01`) encourages even expert utilization. Each expert uses the same LeakyReLU(0.5)² activation as the standard MLP. -```bash -cd /workspace -git clone https://github.com/openai/parameter-golf.git -cd parameter-golf -``` +**Result:** MoE achieved the best pre-quantization BPB of all experiments (1.4291), demonstrating that increased capacity does help language modeling quality. However, the total artifact was 45.4MB — nearly 3x over the 16MB budget. The 4 expert MLPs each have independent weight matrices that learn different specializations, making them highly incompressible — brotli cannot exploit cross-expert redundancy. -Download our cached version of FineWeb. We'll use the 1024-token vocabulary for now. +**Verdict: Interesting but impractical.** MoE improves BPB meaningfully but is fundamentally incompatible with the 16MB constraint. Would require sub-2-bit quantization or expert weight sharing to fit, both of which would likely negate the quality gains. (Not included in final code) -```bash -python3 data/cached_challenge_fineweb.py --variant sp1024 -``` +### 5. KAN (Kolmogorov-Arnold Networks) -This defaults to the full validation split plus 80 training shards (8B tokens). If you only want a smaller subset while iterating, pass `--train-shards N`, for example `--train-shards 1`. +**Motivation:** KAN replaces traditional linear layers + fixed activations with learned activation functions parameterized by B-spline basis functions. Based on the Kolmogorov-Arnold representation theorem, KAN layers can theoretically approximate any continuous function with fewer parameters than standard MLPs for certain function classes. The hypothesis was that KAN's superior function approximation capability might achieve better BPB per parameter, offsetting the overhead of spline coefficients. -Launch your first training run. Note that we're passing `nproc_per_node=1` because we're running on a single H100 GPU in this case. +**How it works:** Each KAN layer parameterizes its activation as a B-spline with `grid_size=5` control points and `spline_order=3`. The spline weights are 3D tensors of shape `(out_features, in_features, num_coefficients)`. Since GPTQ's Hessian collection hooks only attach to standard `nn.Linear` modules, a fallback simple quantization (round-to-nearest with row-wise scaling) was implemented for KAN's spline weight parameters. -```bash -RUN_ID=baseline_sp1024 \ -DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -torchrun --standalone --nproc_per_node=1 train_gpt.py -``` +**Result:** KAN produced the largest artifact (55MB, 3.4x over budget) despite achieving worse BPB. The spline weights are inherently difficult to compress as they represent smooth continuous functions where every coefficient matters, so quantization introduces visible artifacts and brotli cannot find redundancy in the learned spline shapes. KAN also trained significantly slower than standard MLPs due to the B-spline evaluation overhead. -By default, `train_gpt.py` keeps its ~10 minute wallclock cap. If you want a longer run, override it explicitly, for example `MAX_WALLCLOCK_SECONDS=0`. +**Verdict: Negative.** KAN is fundamentally mismatched with the parameter golf constraint. The spline parameters are expensive in both raw size and compression ratio, and the function approximation benefits do not materialize at this model scale and training budget. (Not included in final code) -By default, this command prints `train_loss` step logs during training and prints `val_loss`, `val_bpb`, and compressed model size in the final `final_int8_zlib_roundtrip` lines at the end. If you want periodic validation logs during the run, set `VAL_LOSS_EVERY`, for example `VAL_LOSS_EVERY=200`. For the baseline config, the final `val_bpb` should land around ~1.2 with a compressed model size under 16MB. +--- -For dataset export, tokenizer export, and docs-cache rebuild instructions, see [data/README.md](data/README.md). +## Established Techniques Adopted (Not Novel) -- Credits -Evaluation will be in the RunPod environment with all packages installed. `requirements.txt` is provided as a reference if you want to self-setup. +These techniques were adopted from the top leaderboard submissions or other pull requests and are not novel contributions of this work. -## FAQ +| Technique | Source | Effect | +|---|---|---| +| Parallel Residuals (GPT-J style, layer 7+) | PR #1412 | Attention and MLP read from same pre-attention input | +| Test-Time Training (Score-First SGD) | PR #1413 | ~0.03 BPB improvement via eval-time adaptation | +| QK Gain = 5.25 | PR #1493 | Per-head learnable query scaling | +| Recurrence Loop 3-5, enabled at 35% | PR #1437 | Wider and earlier depth recurrence | +| Warmdown = 0.72, matrix_lr = 0.022 | PR #1445 | Hyperparameter tuning | +| muon_wd = 0.095, ema_decay = 0.9965 | PR #1445 | Optimizer and EMA tuning | -**What exactly counts toward the 16MB artifact size?** +--- -The submission artifact is computed as code bytes plus compressed model bytes. All counted code should live in the `train_gpt.py` script. -The cap is decimal 16MB, i.e. 16,000,000 total bytes, not 16 MiB / 16,777,216 bytes. -No external downloads, training dataset access, or network calls are allowed during evaluation. The artifact must be fully self-contained and reproducible. +## Artifact Size Note -**Are scores independently verified by OpenAI?** +Some configurations are marginally over the 16MB limit. The submitted `train_gpt.py` is ~68KB because it includes all experimental toggles (CAT, Sparsity, MoE configuration constants, multiple test modes, detailed logging). For a competition submission, several approaches would bring this under budget: -We're not automatically verifying every submission, but we will verify the top leaderboard entries over time. Any non-reproducible results can be disqualified, and issues reproducing submissions should be raised on the PR. If you find an issue with a record on the leaderboard or find a record isn't reproducible, please let us know and add an Github Issue describing your findings. +- **LZMA-compressing the training script**, as the top submissions do, which reduces code size. +- **Stripping unused code paths** (removing CAT, Sparsity) to reduce raw code size before compression. +- **Slightly reducing GPTQ calibration batches** (from 64 to 48) to shave a few KB from the quantized model. -**What counts as 'external compute'? For example, is it fair to tune my hyperparameters offline?** +Since this is a non-record negative results submission focused on documenting technique exploration rather than competing for SOTA, we include the full uncompressed script for readability. -There's no perfectly clear answer here and it's hard to draw a clean line around what does or does not count as external compute. For now, we're reserving the right to disqualify runs that are not in the spirit of the challenge. Tuning your Adam hyperparameters across a bunch of runs is fine, but if there's evidence that you're sneaking in additional compute unfairly, such as brute-forcing ridiculous seeds, we won't allow it. Use your best judgment and there's no penalty for asking questions. +--- -**What are the restrictions on evaluation?** +## Lessons Learned -We won't accept submissions that take more than 10 minutes on 8xH100 to evaluate (Note: This limit is in addition to the 10 minutes of training time allowed!), but otherwise you're free to evaluate however. As with modded-nanogpt, we allow evaluation at any sequence length. And, obviously, you aren't allowed to access any training data during evaluation, unless you pay for those bits in the <16MB limit. We encourage competitors to push the bounds of evaluation methods as aggressively as with training methods. You CANNOT access validation data during training, e.g. by compressing it into your 16mb with "paid prefix". +1. **Post-training compression is already near-optimal.** GPTQ + byte-shuffle + brotli-11 is extremely effective. Training-time techniques like CAT provide marginal compression gains (~0.4%) that do not justify the BPB cost. -If it isn't abundantly obvious: You can't cheat on your test loss. You can't cheat by training on the validation set before you evaluate on the validation set. The validation language around test-time training has been confusing people: you are only allowed to test-time train on validation set tokens _you've already evaluated your model on_, since those tokens have already been graded! +2. **Sparsity at 50% might be too aggressive at this scale.** Whether magnitude-based or Hessian-guided, zeroing half the MLP weights at ~36M parameters destroys more information than extra layers can recover. It would be worth looking into further optimizations to take advantage of the additional space provided from incorporating sparsity. -**What is the process for accepting new submissions?** +3. **MoE and KAN explode model size.** Expert weights and spline parameters are inherently difficult to compress. Neither architecture is compatible with extreme compression constraints without fundamentally different quantization approaches. -Since all submissions are public, we're accepting record submissions chronologically depending on their PR creation time. The leaderboard may take time to update due to verification and review of submissions, so pay consideration to what the current SOTA PR is when submitting. As explained below, submissions should exceed the SOTA record with sufficient statistical significance in order to be accepted for the leaderboard. Otherwise, submissions may be accepted as 'non-record submissions' given they are sufficiently unique or interesting. +--- -**Can I import XYZ package or library?** +## Log Files -Yes, you're free to import any package or library you want, so long as it does not unjustly violate the rules on evaluation, compute, training time, code size or otherwise. Just include a requirements.txt in your records folder and mention setup instructions in your README.md. Since you don't pay for bits imported in Python libraries, limitations clearly apply: You can't sneak in extra compute, capabilities, or massively increase effective code size with custom libraries, but importing FlashAttention, etc. is completely fine. +All runs were executed on 1xH100 SXM 80GB via RunPod. +### Round 1 — Novel techniques on PR #1394 baseline +- `train_round1_baseline.log` — 11L baseline +- `train_round1_sparsity_13L.log` — 13L + 2:4 sparsity +- `train_round1_moe_4e.log` — 11L + MoE (4 experts, top-2) +- `train_round1_kan.log` — 11L + KAN (grid=5, order=3) +- `train_round1_cat.log` — 11L + CAT (every 50 steps) +- `train_round1_cat_sparsity_13L.log` — 13L + CAT + sparsity -## Submission Process +### Round 2 — Top submission defaults + novel combos +- `train_round2_baseline_11L.log` — 11L baseline (parallel residuals, QK 5.25, loop 3-5) +- `train_round2_12L_cat_sparse.log` — 12L + CAT + sparsity +- `train_round2_13L_cat_sparse.log` — 13L + CAT + sparsity +- `train_round2_11L_cat.log` — 11L + CAT (no sparsity) +- `train_round2_12L_wideloop.log` — 12L + CAT + sparsity + wide loop (3-6) -New SOTA records must fulfill the following criteria: - -1. They must beat the existing SOTA by at least 0.005 nats. As in modded-nanogpt, because of inter-run variance all submissions must provide enough run logs to show at `p < 0.01` that they achieved the required 0.005-nat improvement. For submissions that improve speed through systems optimization without changing the ML, this requirement is waived. - -2. If changes are made to the tokenizer or dataset, prove with certainty that the val_bpb is correctly calculated. Submissions that edit the tokenizer will be examined much more carefully, since bugs may unjustly improve your score. - -3. Reproducibly run in under 10 minutes on 8xH100s. - -All submissions should be made as a pull request that only adds a new folder to the appropriate `/records` subfolder and includes the following files. Submissions without the full set of requirements will not be accepted. - -1. A README.md file that explains the submission in reasonable detail. - -2. A `submission.json` file (see the example runs) that includes your name, GitHub ID, `val_bpb`, and related metadata. - -3. A train log, automatically produced by your script. Please demonstrate a statistically significant win. Most often, submitting an average over 3 training runs is sufficient. - -4. A `train_gpt.py` script and any other dependencies. Note: this must successfully compile and run within the records folder. Broken scripts will not be accepted. - -### Non-record Submissions - -Submissions are also open to unique and interesting approaches that might not beat the existing SOTA, but still satisfy the 16MB artifact limit. We strongly encourage participants to submit implementations for weird or out-of-the-box ideas, in-progress or unoptimized solutions, so long as they run successfully, or even interesting negative results. We're excited to see what you come up with. We'll still maintain a high bar for non-record submissions, so be sure to justify your ideas and results in detail when submitting. - -We also accept non-record submissions to an unlimited compute track for runs that are not intended to meet the 10-minute cutoff. Just note as such in your README file. - -Non-record submissions should be made in the same fashion as SOTA records, as described above. - -#### PRs on Core Code - -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but the best models should stay in the `/records` folder. - -## Support - - -Join the [OpenAI Discord server](https://discord.com/invite/openai) and visit the Parameter Golf channels (#parameter-golf-discussions, #parameter-golf-announcements) and ask questions. - -This repository adapts code from `modded-nanogpt`, see [THIRD_PARTY_NOTICES.md](THIRD_PARTY_NOTICES.md) for attribution. +### Round 3 — TTT + Hessian-guided sparsity +- `train_round3_baseline_ttt.log` — 11L baseline + TTT (best result) +- `train_round3_11L_cat_hsparse_ttt.log` — 11L + CAT + Hessian sparsity + TTT +- `train_round3_12L_cat_hsparse_ttt.log` — 12L + CAT + Hessian sparsity + TTT diff --git a/THIRD_PARTY_NOTICES.md b/THIRD_PARTY_NOTICES.md deleted file mode 100644 index ae833d3725..0000000000 --- a/THIRD_PARTY_NOTICES.md +++ /dev/null @@ -1,29 +0,0 @@ -# Third-Party Notices - -## modded-nanogpt - -Parts of this repository were adapted from [KellerJordan/modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt), which is licensed under the MIT License. - -Original source: https://github.com/KellerJordan/modded-nanogpt - -MIT License - -Copyright (c) 2024 Keller Jordan - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/data/README.md b/data/README.md deleted file mode 100644 index e1920ad9d4..0000000000 --- a/data/README.md +++ /dev/null @@ -1,66 +0,0 @@ -# Data Workflows - -This directory contains the dataset download helpers and export scripts used for the challenge. - -Canonical local layout: -- `data/datasets//` -- `data/tokenizers/` -- `data/manifest.json` -- `data/docs_selected.jsonl` -- `data/docs_selected.source_manifest.json` - -## Downloading Published Data - -Download the cached FineWeb export for a tokenizer variant with: - -```bash -python3 data/cached_challenge_fineweb.py --variant sp1024 -``` - -This populates `./data/datasets/fineweb10B_sp1024/` and `./data/tokenizers/`. -By default it downloads the full validation split and 8B training tokens (80 train shards). - -To fetch more training shards, pass `--train-shards`: - -```bash -python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 180 -``` - -The downloader is manifest-driven and can fetch only a prefix of train shards from a larger published export. With the current shard size of `100_000_000` tokens, `10B` retokenized training tokens is `100` train shards: - -```bash -MATCHED_FINEWEB_REPO_ID=your-hf-username/your-dataset-repo \ -MATCHED_FINEWEB_REMOTE_ROOT_PREFIX=your_50B_export_root \ -python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 100 -``` - -Validation is always downloaded in full from the fixed `fineweb_val_*` split. Training on the first `N` train shards means training on the prefix of the same frozen shuffled export, so the data order stays aligned with the baseline for that tokenizer family. - -The default published repo is `willdepueoai/parameter-golf`, with the export rooted under the repo subdirectory `datasets/`. - -## Rebuilding Tokenizers From Published Docs - -To retrain a tokenizer or re-export shards from exactly the same selected documents, run the standalone retokenizer against the published docs cache: - -```bash -python3 data/download_hf_docs_and_tokenize.py \ - --repo-id your-hf-username/your-dataset-repo \ - --remote-root your_50B_export_root \ - --output-root /tmp/my_custom_tokenizer_export \ - --tokenizer-config ./data/tokenizer_specs.json -``` - -The sidecar `docs_selected.source_manifest.json` includes `docs_sha256`, so users can verify they are rebuilding from the exact same document list and order as the baseline export. - -## Useful Knobs - -For CPU-heavy exports, useful knobs are: - -```bash -MATCHED_FINEWEB_SP_BATCH_SIZE=2048 -MATCHED_FINEWEB_TOKENIZER_THREADS=16 -MATCHED_FINEWEB_TIKTOKEN_THREADS=16 -MATCHED_FINEWEB_GPT2_DECODE_BATCH_SIZE=512 -``` - -These control batched tokenizer encoding during shard export, tokenizer thread count, tiktoken thread count, and batched GPT-2 decode for the blobstore docs-cache path. diff --git a/data/cached_challenge_fineweb.py b/data/cached_challenge_fineweb.py deleted file mode 100644 index fa8029be42..0000000000 --- a/data/cached_challenge_fineweb.py +++ /dev/null @@ -1,157 +0,0 @@ -import argparse -import json -import os -import shutil -from pathlib import Path - -from huggingface_hub import hf_hub_download - - -REPO_ID = os.environ.get("MATCHED_FINEWEB_REPO_ID", "willdepueoai/parameter-golf") -REMOTE_ROOT_PREFIX = os.environ.get("MATCHED_FINEWEB_REMOTE_ROOT_PREFIX", "datasets") -ROOT = Path(__file__).resolve().parent -DATASETS_DIR = ROOT / "datasets" -TOKENIZERS_DIR = ROOT / "tokenizers" - -def dataset_dir_for_variant(name: str) -> str: - if name == "byte260": - return "fineweb10B_byte260" - if name.startswith("sp") and name[2:].isdigit(): - return f"fineweb10B_{name}" - raise ValueError(f"unsupported variant {name!r}; expected byte260 or sp") - - -def local_path_for_remote(relative_path: str) -> Path: - remote_path = Path(relative_path) - if REMOTE_ROOT_PREFIX and remote_path.parts[:1] == (REMOTE_ROOT_PREFIX,): - remote_path = remote_path.relative_to(REMOTE_ROOT_PREFIX) - if remote_path.parts[:1] == ("datasets",): - return DATASETS_DIR.joinpath(*remote_path.parts[1:]) - if remote_path.parts[:1] == ("tokenizers",): - return TOKENIZERS_DIR.joinpath(*remote_path.parts[1:]) - return ROOT / remote_path - - -def get(relative_path: str) -> None: - destination = local_path_for_remote(relative_path) - if destination.exists(): - return - if destination.is_symlink(): - destination.unlink() - - remote_path = Path(relative_path) - cached_path = Path( - hf_hub_download( - repo_id=REPO_ID, - filename=remote_path.name, - subfolder=remote_path.parent.as_posix() if remote_path.parent != Path(".") else None, - repo_type="dataset", - ) - ) - # HF cache entries may be snapshot symlinks. Resolve to the underlying blob so we - # always materialize a real file in data/, not a broken relative symlink. - cached_source = cached_path.resolve(strict=True) - destination.parent.mkdir(parents=True, exist_ok=True) - try: - os.link(cached_source, destination) - except OSError: - shutil.copy2(cached_source, destination) - - -def manifest_path() -> Path: - return local_path_for_remote(f"{REMOTE_ROOT_PREFIX}/manifest.json") - - -def load_manifest(*, skip_manifest_download: bool) -> dict: - path = manifest_path() - if not path.is_file(): - if skip_manifest_download: - raise FileNotFoundError( - f"manifest.json is required for manifest-driven shard counts but is not present locally at {path}" - ) - get(f"{REMOTE_ROOT_PREFIX}/manifest.json") - return json.loads(path.read_text(encoding="utf-8")) - - -def artifact_paths_for_tokenizer(tokenizer_entry: dict) -> list[str]: - artifacts = [] - for key in ("model_path", "vocab_path", "path"): - value = tokenizer_entry.get(key) - if value: - artifacts.append(str(value)) - if not artifacts: - raise ValueError(f"tokenizer entry is missing downloadable artifacts: {tokenizer_entry}") - return artifacts - - -def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Download challenge FineWeb shards from Hugging Face") - parser.add_argument( - "train_shards_positional", - nargs="?", - type=int, - default=None, - help=argparse.SUPPRESS, - ) - parser.add_argument( - "--train-shards", - type=int, - default=80, - help="Number of training shards to download for the selected variant. Defaults to 80.", - ) - parser.add_argument( - "--variant", - default="sp1024", - help="Tokenizer family to download, for example sp1024, sp4096, or byte260.", - ) - parser.add_argument( - "--skip-manifest", - action="store_true", - help="Skip downloading manifest.json.", - ) - parser.add_argument( - "--with-docs", - action="store_true", - help="Also download docs_selected.jsonl and its sidecar for tokenizer retraining or dataset re-export.", - ) - return parser - - -def main() -> None: - args = build_parser().parse_args() - dataset_dir = dataset_dir_for_variant(args.variant) - train_shards = args.train_shards_positional if args.train_shards_positional is not None else args.train_shards - if train_shards < 0: - raise ValueError("train_shards must be non-negative") - - manifest = load_manifest(skip_manifest_download=args.skip_manifest) - dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir), None) - if dataset_entry is None: - raise ValueError(f"dataset {dataset_dir} not found in {REMOTE_ROOT_PREFIX}/manifest.json") - max_train_shards = int((dataset_entry.get("stats") or {}).get("files_train")) - val_shards = int((dataset_entry.get("stats") or {}).get("files_val")) - if train_shards > max_train_shards: - raise ValueError( - f"{args.variant} only has {max_train_shards} training shards on {REPO_ID}, requested {train_shards}" - ) - tokenizer_name = dataset_entry.get("tokenizer_name") - tokenizer_entry = next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) - if tokenizer_entry is None: - raise ValueError(f"tokenizer {tokenizer_name} not found in {REMOTE_ROOT_PREFIX}/manifest.json") - - if args.with_docs: - get(f"{REMOTE_ROOT_PREFIX}/docs_selected.jsonl") - get(f"{REMOTE_ROOT_PREFIX}/docs_selected.source_manifest.json") - - dataset_prefix = f"{REMOTE_ROOT_PREFIX}/datasets/{dataset_dir}" - for i in range(val_shards): - get(f"{dataset_prefix}/fineweb_val_{i:06d}.bin") - for i in range(train_shards): - get(f"{dataset_prefix}/fineweb_train_{i:06d}.bin") - - for artifact_path in artifact_paths_for_tokenizer(tokenizer_entry): - get(f"{REMOTE_ROOT_PREFIX}/{artifact_path}") - - -if __name__ == "__main__": - main() diff --git a/data/download_hf_docs_and_tokenize.py b/data/download_hf_docs_and_tokenize.py deleted file mode 100644 index dcabd40b75..0000000000 --- a/data/download_hf_docs_and_tokenize.py +++ /dev/null @@ -1,627 +0,0 @@ -"""Download docs_selected.jsonl from Hugging Face and tokenize it locally. - -This script is standalone. It does not import any local exporter or tokenizer -helpers. Tokenizer configs are JSON only and currently support the built-in -pure-byte and SentencePiece tokenizer definitions in `data/tokenizer_specs.json`. -""" - -from __future__ import annotations - -import argparse -import json -import os -import shutil -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any - -import numpy as np -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError - - -DOCS_FILENAME = "docs_selected.jsonl" -SIDECAR_FILENAME = "docs_selected.source_manifest.json" -VERSION = "10B" -NUM_VAL_DOCS = 50_000 -SHARD_SIZE = 10**8 -APPEND_EOS = False -DATAFILE_MAGIC = 20240520 -DATAFILE_VERSION = 1 -DEFAULT_REPO_ID = os.environ.get("MATCHED_FINEWEB_REPO_ID", "willdepueoai/parameter-golf") -DEFAULT_REMOTE_ROOT = os.environ.get("MATCHED_FINEWEB_REMOTE_ROOT_PREFIX", "datasets") -DEFAULT_CONFIG = Path(__file__).with_name("tokenizer_specs.json") -TOKENIZER_THREADS = max(1, int(os.environ.get("MATCHED_FINEWEB_TOKENIZER_THREADS", str(os.cpu_count() or 8)))) -SP_BATCH_SIZE = max(1, int(os.environ.get("MATCHED_FINEWEB_SP_BATCH_SIZE", "1024"))) - - -@dataclass(frozen=True) -class PureByteTokenizer: - pad_id: int = 0 - bos_id: int = 1 - eos_id: int = 2 - unk_id: int = 3 - byte_offset: int = 4 - byte_count: int = 256 - - @property - def vocab_size(self) -> int: - return self.byte_offset + self.byte_count - - def encode(self, text: str) -> np.ndarray: - data = text.encode("utf-8", errors="replace") - return np.frombuffer(data, dtype=np.uint8).astype(np.uint16, copy=False) + self.byte_offset - - def encode_batch(self, texts: list[str]) -> list[np.ndarray]: - return [self.encode(text) for text in texts] - - def save_json(self, path: str | Path) -> None: - path = Path(path) - path.parent.mkdir(parents=True, exist_ok=True) - payload = { - "tokenizer_type": "pure_byte", - "config": asdict(self), - "vocab_size": self.vocab_size, - } - path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") - - -def default_pure_byte_tokenizer() -> PureByteTokenizer: - return PureByteTokenizer() - - -def docs_sidecar_path(docs_jsonl: Path) -> Path: - return docs_jsonl.with_name(f"{docs_jsonl.stem}.source_manifest.json") - - -def maybe_load_docs_sidecar_meta(docs_jsonl: Path) -> dict[str, Any] | None: - sidecar_path = docs_sidecar_path(docs_jsonl) - if not sidecar_path.is_file(): - return None - payload = json.loads(sidecar_path.read_text(encoding="utf-8")) - if not isinstance(payload, dict): - raise ValueError(f"docs sidecar must be a JSON object: {sidecar_path}") - return payload - - -def copy_from_hf_cache(*, repo_id: str, remote_root: str, filename: str, destination: Path) -> bool: - remote_path = Path(remote_root) / filename if remote_root else Path(filename) - try: - cached_path = Path( - hf_hub_download( - repo_id=repo_id, - filename=remote_path.name, - subfolder=remote_path.parent.as_posix() if remote_path.parent != Path(".") else None, - repo_type="dataset", - ) - ) - except EntryNotFoundError: - return False - - source = cached_path.resolve(strict=True) - destination.parent.mkdir(parents=True, exist_ok=True) - if destination.exists(): - destination.unlink() - try: - os.link(source, destination) - except OSError: - shutil.copy2(source, destination) - return True - - -def iter_docs(path: Path): - with path.open("r", encoding="utf-8") as f: - for line in f: - yield json.loads(line)["text"] - - -def count_docs(path: Path) -> int: - with path.open("r", encoding="utf-8") as f: - return sum(1 for _ in f) - - -def batched_docs_jsonl(path: Path, batch_size: int): - batch: list[str] = [] - for text in iter_docs(path): - batch.append(text) - if len(batch) == batch_size: - yield batch - batch = [] - if batch: - yield batch - - -def write_datafile(path: Path, toks: Any) -> None: - if len(toks) >= 2**31: - raise ValueError("token count too large") - header = np.zeros(256, dtype=" Any: - if isinstance(value, dict): - return {k: relativize_manifest_paths(v, root) for k, v in value.items()} - if isinstance(value, list): - return [relativize_manifest_paths(v, root) for v in value] - if isinstance(value, str): - path = Path(value) - if path.is_absolute(): - try: - return path.relative_to(root).as_posix() - except ValueError: - return value - return value - - -def parse_reuse_sp_models(values: list[str]) -> dict[int, Path]: - reuse_models: dict[int, Path] = {} - for value in values: - vocab_size_str, model_path = value.split("=", 1) - vocab_size = int(vocab_size_str) - if vocab_size in reuse_models: - raise ValueError(f"duplicate --reuse_sp_model for vocab_size={vocab_size}") - reuse_models[vocab_size] = Path(model_path).expanduser().resolve() - return reuse_models - - -def load_specs(config_path: Path) -> list[dict[str, Any]]: - payload = json.loads(config_path.read_text(encoding="utf-8")) - if isinstance(payload, dict): - specs = payload.get("tokenizer_specs", payload.get("tokenizers")) - else: - specs = payload - if not isinstance(specs, list) or not specs: - raise ValueError("tokenizer_config must define a non-empty list") - if not all(isinstance(spec, dict) for spec in specs): - raise ValueError("each tokenizer spec must be a JSON object") - return [dict(spec) for spec in specs] - - -def tokenizer_kind(spec: dict[str, Any]) -> str: - kind = spec.get("kind") - if kind in {"byte", "pure_byte"}: - return "byte" - if kind in {"sentencepiece_bpe", "sentencepiece"}: - return "sentencepiece_bpe" - builder = str(spec.get("builder", "")) - builder_name = builder.rsplit(":", 1)[-1] - if builder_name == "build_pure_byte_tokenizer": - return "byte" - if builder_name == "build_sentencepiece_tokenizer": - return "sentencepiece_bpe" - if spec.get("dataset_suffix") == "byte260": - return "byte" - if "vocab_size" in spec: - return "sentencepiece_bpe" - raise ValueError( - f"unsupported tokenizer spec {spec.get('name', '')!r}: " - "expected a built-in pure-byte or sentencepiece builder" - ) - - -def write_tokenizer_config_export(output_root: Path, selected_specs: list[dict[str, Any]]) -> Path: - path = output_root / "tokenizer_config.export.json" - path.write_text(json.dumps({"tokenizers": selected_specs}, indent=2) + "\n", encoding="utf-8") - return path - - -def _iter_sentencepiece_text(docs_jsonl: Path, *, max_docs: int | None = None): - with docs_jsonl.open("r", encoding="utf-8") as f: - for i, line in enumerate(f): - if max_docs is not None and i >= max_docs: - break - text = json.loads(line)["text"].replace("\x00", " ").strip() - if text: - yield text - - -def build_pure_byte_tokenizer(*, spec: dict[str, Any], docs_jsonl: Path, tokenizers_dir: Path) -> dict[str, Any]: - del docs_jsonl - tok = default_pure_byte_tokenizer() - path = tokenizers_dir / spec.get("filename", "fineweb_pure_byte_260.json") - tok.save_json(path) - return { - "name": spec.get("name", "pure_byte_260"), - "kind": "byte", - "dataset_suffix": spec.get("dataset_suffix", "byte260"), - "vocab_size": tok.vocab_size, - "bos_id": tok.bos_id, - "eos_id": tok.eos_id, - "encode": tok.encode, - "encode_batch": tok.encode_batch, - "manifest": {"path": str(path), "pad_id": tok.pad_id, "unk_id": tok.unk_id}, - } - - -def build_sentencepiece_tokenizer(*, spec: dict[str, Any], docs_jsonl: Path, tokenizers_dir: Path) -> dict[str, Any]: - try: - import sentencepiece as spm - except ImportError as exc: - raise RuntimeError("sentencepiece is required for SentencePiece tokenizer exports") from exc - - vocab_size = int(spec["vocab_size"]) - prefix = tokenizers_dir / spec.get("model_prefix", f"fineweb_{vocab_size}_bpe") - model_path = prefix.with_suffix(".model") - vocab_path = prefix.with_suffix(".vocab") - prefix.parent.mkdir(parents=True, exist_ok=True) - for artifact in (model_path, vocab_path): - if artifact.exists(): - artifact.unlink() - - reuse_model_path = spec.get("reuse_model_path") - if reuse_model_path is not None: - reuse_model_path = Path(reuse_model_path).expanduser().resolve() - if not reuse_model_path.is_file(): - raise FileNotFoundError(reuse_model_path) - shutil.copy2(reuse_model_path, model_path) - reuse_vocab_path = reuse_model_path.with_suffix(".vocab") - if reuse_vocab_path.is_file(): - shutil.copy2(reuse_vocab_path, vocab_path) - else: - kwargs = { - "sentence_iterator": _iter_sentencepiece_text( - docs_jsonl, - max_docs=None if spec.get("tokenizer_train_docs") is None else int(spec["tokenizer_train_docs"]), - ), - "model_prefix": str(prefix), - "model_type": "bpe", - "vocab_size": vocab_size, - "character_coverage": 0.999, - "byte_fallback": True, - "split_digits": True, - "normalization_rule_name": "nmt_nfkc", - "add_dummy_prefix": False, - "pad_id": 0, - "bos_id": 1, - "eos_id": 2, - "unk_id": 3, - "hard_vocab_limit": False, - } - kwargs.update(spec.get("trainer_overrides") or {}) - spm.SentencePieceTrainer.train(**kwargs) - - tok = spm.SentencePieceProcessor(model_file=str(model_path)) - return { - "name": spec.get("name", f"sp_bpe_{vocab_size}"), - "kind": "sentencepiece_bpe", - "dataset_suffix": spec.get("dataset_suffix", f"sp{vocab_size}"), - "vocab_size": int(tok.vocab_size()), - "bos_id": int(tok.bos_id()), - "eos_id": int(tok.eos_id()), - "encode": lambda text, tok=tok: tok.encode(text, out_type=int), - "encode_batch": lambda texts, tok=tok: tok.encode(texts, out_type=int, num_threads=TOKENIZER_THREADS), - "manifest": {"model_path": str(model_path), "vocab_path": str(vocab_path)}, - } - - -def export_shards( - docs_jsonl: Path, - tok: dict[str, Any], - output_dir: Path, - *, - num_val_docs: int, - shard_size: int, - docs_total: int, -) -> dict[str, int]: - output_dir.mkdir(parents=True, exist_ok=True) - for pattern in ("fineweb_train_*.bin", "fineweb_val_*.bin"): - for stale in output_dir.glob(pattern): - stale.unlink() - - stats = { - "docs_total": 0, - "docs_val": 0, - "docs_train": 0, - "files_total": 0, - "files_val": 0, - "files_train": 0, - "tokens_total": 0, - "tokens_val": 0, - "tokens_train": 0, - } - buf = np.empty((shard_size,), dtype=np.uint16) - fill = 0 - split = "val" - shards = {"val": 0, "train": 0} - - def flush() -> None: - nonlocal fill - if fill == 0: - return - write_datafile(output_dir / f"fineweb_{split}_{shards[split]:06d}.bin", buf[:fill]) - stats["files_total"] += 1 - stats[f"files_{split}"] += 1 - shards[split] += 1 - fill = 0 - - vocab_size = int(tok["vocab_size"]) - if vocab_size > 2**16: - raise ValueError(f"vocab_size={vocab_size} is too large for uint16 shard storage") - - batch_encode = tok.get("encode_batch") - batch_size = SP_BATCH_SIZE if callable(batch_encode) else 1 - for texts in batched_docs_jsonl(docs_jsonl, batch_size): - encoded_docs = batch_encode(texts) if callable(batch_encode) else [tok["encode"](text) for text in texts] - for text, encoded in zip(texts, encoded_docs, strict=True): - del text - split_for_doc = "val" if stats["docs_total"] < num_val_docs else "train" - if split_for_doc != split: - flush() - split = split_for_doc - - encoded_arr = np.asarray(encoded, dtype=np.int32) - toks = np.empty((encoded_arr.size + 1 + int(APPEND_EOS),), dtype=np.int32) - toks[0] = tok["bos_id"] - toks[1 : 1 + encoded_arr.size] = encoded_arr - if APPEND_EOS: - toks[-1] = tok["eos_id"] - if not ((0 <= toks).all() and (toks < vocab_size).all()): - bad = int(toks[(toks < 0) | (toks >= vocab_size)][0]) - raise ValueError(f"token id {bad} outside declared vocab_size={vocab_size}") - toks = toks.astype(" tuple[list[dict[str, Any]], list[dict[str, Any]]]: - tokenizers: list[dict[str, Any]] = [] - selected_specs: list[dict[str, Any]] = [] - seen_names: set[str] = set() - seen_datasets: set[str] = set() - - for raw_spec in specs: - spec = dict(raw_spec) - kind = tokenizer_kind(spec) - if skip_byte and kind == "byte": - continue - if kind == "sentencepiece_bpe": - if tokenizer_train_docs is not None: - spec["tokenizer_train_docs"] = int(tokenizer_train_docs) - vocab_size = int(spec["vocab_size"]) - if vocab_size in reuse_sp_models: - spec["reuse_model_path"] = str(reuse_sp_models[vocab_size]) - - selected_specs.append(spec) - built = ( - build_pure_byte_tokenizer(spec=spec, docs_jsonl=docs_jsonl, tokenizers_dir=tokenizers_dir) - if kind == "byte" - else build_sentencepiece_tokenizer(spec=spec, docs_jsonl=docs_jsonl, tokenizers_dir=tokenizers_dir) - ) - name = str(built["name"]) - dataset_suffix = built.get("dataset_suffix") - dataset_name = str(built.get("dataset_name", f"fineweb{VERSION}_{dataset_suffix}")) - if name in seen_names: - raise ValueError(f"duplicate tokenizer name: {name}") - if dataset_name in seen_datasets: - raise ValueError(f"duplicate dataset name: {dataset_name}") - seen_names.add(name) - seen_datasets.add(dataset_name) - vocab_size = int(built["vocab_size"]) - recommended_bigram_vocab_size = int( - built.get("recommended_bigram_vocab_size", ((vocab_size + 127) // 128) * 128 * 5) - ) - tokenizers.append( - { - "name": name, - "kind": str(built["kind"]), - "dataset_name": dataset_name, - "vocab_size": vocab_size, - "bos_id": int(built["bos_id"]), - "eos_id": int(built["eos_id"]), - "encode": built["encode"], - "encode_batch": built.get("encode_batch"), - "recommended_bigram_vocab_size": recommended_bigram_vocab_size, - "manifest": { - "name": name, - "kind": str(built["kind"]), - "vocab_size": vocab_size, - "bos_id": int(built["bos_id"]), - "eos_id": int(built["eos_id"]), - "recommended_bigram_vocab_size": recommended_bigram_vocab_size, - "source_spec": spec, - **(built.get("manifest") or {}), - }, - } - ) - if not tokenizers: - raise ValueError("tokenizer_config produced no tokenizers after filtering") - return tokenizers, selected_specs - - -def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description="Download docs_selected.jsonl from a Hugging Face dataset repo and tokenize it locally" - ) - parser.add_argument( - "--repo-id", - default=DEFAULT_REPO_ID, - help="Hugging Face dataset repo id, for example user/dataset", - ) - parser.add_argument( - "--remote-root", - default=DEFAULT_REMOTE_ROOT, - help="Optional subdirectory inside the dataset repo that contains docs_selected.jsonl", - ) - parser.add_argument("--output-root", required=True, help="Directory where docs, tokenizers, shards, and manifest are written") - parser.add_argument( - "--tokenizer-config", - default=str(DEFAULT_CONFIG), - help="Local tokenizer config JSON. Defaults to data/tokenizer_specs.json.", - ) - parser.add_argument( - "--num-val-docs", - type=int, - default=None, - help="Validation document count. Defaults to the downloaded sidecar when present, otherwise 50000.", - ) - parser.add_argument("--chunk-tokens", type=int, default=SHARD_SIZE, help="Shard size in tokens.") - parser.add_argument( - "--tokenizer-train-docs", - type=int, - default=None, - help="Limit the number of docs used for tokenizer training.", - ) - parser.add_argument("--skip-byte", action="store_true", help="Skip byte-tokenizer export.") - parser.add_argument( - "--reuse-sp-model", - action="append", - default=[], - metavar="VOCAB=MODEL", - help="Reuse an existing SentencePiece model for the given vocab size instead of retraining it.", - ) - return parser - - -def main() -> None: - args = build_parser().parse_args() - if args.chunk_tokens <= 0: - raise ValueError(f"--chunk_tokens must be positive, got {args.chunk_tokens}") - - output_root = Path(args.output_root).expanduser().resolve() - output_root.mkdir(parents=True, exist_ok=True) - tokenizers_dir = output_root / "tokenizers" - datasets_dir = output_root / "datasets" - tokenizers_dir.mkdir(parents=True, exist_ok=True) - datasets_dir.mkdir(parents=True, exist_ok=True) - - docs_jsonl = output_root / DOCS_FILENAME - sidecar = output_root / SIDECAR_FILENAME - if not copy_from_hf_cache( - repo_id=args.repo_id, - remote_root=args.remote_root, - filename=DOCS_FILENAME, - destination=docs_jsonl, - ): - remote = f"{args.remote_root}/{DOCS_FILENAME}" if args.remote_root else DOCS_FILENAME - raise FileNotFoundError(f"{remote} not found in Hugging Face dataset repo {args.repo_id}") - if not copy_from_hf_cache( - repo_id=args.repo_id, - remote_root=args.remote_root, - filename=SIDECAR_FILENAME, - destination=sidecar, - ): - sidecar.unlink(missing_ok=True) - - docs_sidecar = maybe_load_docs_sidecar_meta(docs_jsonl) - docs_total = int(docs_sidecar["num_docs"]) if docs_sidecar is not None and docs_sidecar.get("num_docs") is not None else count_docs(docs_jsonl) - if args.num_val_docs is not None: - num_val_docs = int(args.num_val_docs) - elif docs_sidecar is not None and docs_sidecar.get("docs_val") is not None: - num_val_docs = int(docs_sidecar["docs_val"]) - else: - num_val_docs = NUM_VAL_DOCS - if not (0 <= num_val_docs <= docs_total): - raise ValueError(f"num_val_docs must be in [0, {docs_total}], got {num_val_docs}") - - specs = load_specs(Path(args.tokenizer_config).expanduser().resolve()) - reuse_sp_models = parse_reuse_sp_models(args.reuse_sp_model) - tokenizers, selected_specs = build_tokenizers( - specs=specs, - docs_jsonl=docs_jsonl, - tokenizers_dir=tokenizers_dir, - tokenizer_train_docs=args.tokenizer_train_docs, - skip_byte=args.skip_byte, - reuse_sp_models=reuse_sp_models, - ) - write_tokenizer_config_export(output_root, selected_specs) - - docs_meta = { - "remote_repo_id": args.repo_id, - "remote_root": args.remote_root, - "num_docs": docs_total, - "docs_sha256": None if docs_sidecar is None else docs_sidecar.get("docs_sha256"), - "source_manifest": str(docs_sidecar_path(docs_jsonl)) if docs_sidecar is not None else None, - } - if docs_sidecar is not None: - docs_meta["source_sidecar"] = docs_sidecar - - manifest = { - "version": VERSION, - "num_docs": docs_total, - "num_val_docs": num_val_docs, - "shuffle_seed": None if docs_sidecar is None else docs_sidecar.get("shuffle_seed"), - "shard_size": int(args.chunk_tokens), - "append_eos": APPEND_EOS, - "docs_jsonl": str(docs_jsonl), - "docs_meta": docs_meta, - "tokenizer_specs": selected_specs, - "tokenizers": [], - "datasets": [], - } - - for tok in tokenizers: - output_dir = datasets_dir / tok["dataset_name"] - print(f"Exporting dataset: {tok['dataset_name']}", flush=True) - stats = export_shards( - docs_jsonl, - tok, - output_dir, - num_val_docs=num_val_docs, - shard_size=int(args.chunk_tokens), - docs_total=docs_total, - ) - manifest["tokenizers"].append(tok["manifest"]) - manifest["datasets"].append( - { - "name": tok["dataset_name"], - "tokenizer_name": tok["name"], - "tokenizer_kind": tok["kind"], - "path": str(output_dir), - "train_glob": str(output_dir / "fineweb_train_*.bin"), - "val_glob": str(output_dir / "fineweb_val_*.bin"), - "vocab_size": tok["vocab_size"], - "bos_id": tok["bos_id"], - "eos_id": tok["eos_id"], - "recommended_bigram_vocab_size": tok["recommended_bigram_vocab_size"], - "stats": stats, - } - ) - - manifest = relativize_manifest_paths(manifest, output_root) - manifest_path = output_root / "manifest.json" - manifest_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8") - print(f"Done. Manifest: {manifest_path}", flush=True) - - -if __name__ == "__main__": - main() diff --git a/data/tokenizer_specs.json b/data/tokenizer_specs.json deleted file mode 100644 index d7ad1ca057..0000000000 --- a/data/tokenizer_specs.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "tokenizers": [ - { - "name": "sp_bpe_1024", - "dataset_suffix": "sp1024", - "vocab_size": 1024 - } - ] -} diff --git a/records/track_10min_16mb/2026-03-17_LoRA_TTT/README.md b/records/track_10min_16mb/2026-03-17_LoRA_TTT/README.md deleted file mode 100644 index ff7447679b..0000000000 --- a/records/track_10min_16mb/2026-03-17_LoRA_TTT/README.md +++ /dev/null @@ -1,59 +0,0 @@ -This record captures `LoRA TTT`: the naive baseline model with document-aware LoRA test-time training at evaluation. - -## Method - -**Training** is identical to the naive baseline. - -**Evaluation** adds per-document LoRA test-time training (TTT). For each document in the validation set: -1. Find document boundaries using BOS tokens -2. Split the document into overlapping chunks (chunk_size=256 within eval_seq_len=1024 context windows) -3. For each chunk, score it (accumulate loss/bytes for BPB), *then* train rank-8 LoRA adapters on that chunk's loss (so you only train on the context -- no leakage) -4. Reset LoRA parameters between documents (no leakake across documents) - -Documents are batched (batch_size=64) and sorted by length for efficiency. The LoRA adapters target `lm_head`, `c_q`, and `c_v` projections in all transformer blocks. A single Adam optimizer with `lr=0.01, betas=(0.9, 0.95)` trains all LoRA parameters with one gradient step per chunk. - -## Notes - -This is very similar to [a record I submmited to the modded nano-gpt speedrun repo](https://samacquaviva.com/projects/nanogpt/). -The major addition is to make the test-time training ~5x faster by using LoRAs: this let's you have per-sequence adaptation (no leaking between validation sequences) while still batching. - -This is not a heavily optimized run: I just wanted to plant the TTT seed. -It uses ~1/10th of the evaluation budget. - -## Ablations - -The majority of this improvement doesn't come from the TTT itself, but from -1). Only conditioning on the current document -2). Doing strided evaluations - -| Condition | val_loss | val_bpb | Delta bpb | -| --------- | -------- | ------- | --------- | -| Baseline (cross-doc, flat stream) | 2.0731 | 1.2278 | — | -| + Doc-isolated | 2.0561 | 1.2168 | -0.0110 | -| + Stride (chunk=256) | 2.0177 | 1.1941 | -0.0337 | -| + LoRA TTT | 2.0126 | 1.1910 | -0.0368 | - -![ablations](ablations.png) - -## Results - -Validated on the full 50k-document fineweb_val split. Submitting at `bpb=1.195`. - -```bash -bpb: [1.1927, 1.1935, 1.1921, 1.1929] -mean: 1.1928 -std: 0.0005 -p-value < 1.195: 0.00234486 -``` - -## Command - -```bash -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -## Included files - -- `train_gpt.py` -- `train_v*.txt` (note that `train_v0.txt` is on 2xH100) -- `submission.json` diff --git a/records/track_10min_16mb/2026-03-17_LoRA_TTT/ablations.png b/records/track_10min_16mb/2026-03-17_LoRA_TTT/ablations.png deleted file mode 100644 index 70454238cd..0000000000 Binary files a/records/track_10min_16mb/2026-03-17_LoRA_TTT/ablations.png and /dev/null differ diff --git a/records/track_10min_16mb/2026-03-17_LoRA_TTT/submission.json b/records/track_10min_16mb/2026-03-17_LoRA_TTT/submission.json deleted file mode 100644 index eccdf0c620..0000000000 --- a/records/track_10min_16mb/2026-03-17_LoRA_TTT/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "sam", - "github_id": "samacqua", - "name": "LoRA TTT", - "blurb": "Naive baseline + per-document LoRA test-time training at eval. Rank-8 LoRA on lm_head/Q/V with Adam lr=0.01, overlapping 256-token chunks in 1024-token context windows. Same training, smarter eval.", - "date": "2026-03-19T10:00:00Z", - "val_loss": 2.0142, - "val_bpb": 1.1929, - "bytes_total": 15882446, - "bytes_code": 58509 -} diff --git a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_gpt.py b/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_gpt.py deleted file mode 100644 index 85e2cc463a..0000000000 --- a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_gpt.py +++ /dev/null @@ -1,1372 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # Test-time training (LoRA) hyperparameters. - ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) - ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) - ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) - ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x) + (q_delta if q_delta is not None else 0) - k = self.c_k(x) - v = self.c_v(x) + (v_delta if v_delta is not None else 0) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - qd = lora.q_loras[i] if lora else None - vd = lora.v_loras[i] if lora else None - x = self.blocks[i](x, x0, qd, vd) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - qd = lora.q_loras[bi] if lora else None - vd = lora.v_loras[bi] if lora else None - x = self.blocks[bi](x, x0, qd, vd) - x = self.final_norm(x) - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) - else: - logits = self.lm_head(x) - logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - if lora: - bsz, sl, V = logits.shape - return F.cross_entropy( - logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") - - -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. - -BOS_ID = 1 - -class BatchedLinearLoRA(nn.Module): - """LoRA for a linear layer, with independent weights per batch element. - Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" - def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): - super().__init__() - self.in_features = in_features - self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection - self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection - self.reset() - - def forward(self, x: Tensor) -> Tensor: - return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) - - def reset(self) -> None: - bound = 1.0 / math.sqrt(self.in_features) - with torch.no_grad(): - self.A.uniform_(-bound, bound) # kaiming-uniform - self.B.zero_() - -class BatchedTTTLoRA(nn.Module): - """All LoRA adapters for one batch: LM head and Q/V per block.""" - def __init__(self, bsz: int, model: GPT, rank: int): - super().__init__() - dim = model.tok_emb.embedding_dim - vocab = model.tok_emb.num_embeddings - self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) - self.q_loras = nn.ModuleList() - self.v_loras = nn.ModuleList() - for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) - - def reset(self) -> None: - for m in self.modules(): - if isinstance(m, BatchedLinearLoRA): - m.reset() - -def _reset_ttt_optimizer(opt): - for group in opt.param_groups: - for p in group['params']: - s = opt.state.get(p) - if not s: # Fresh state. - continue - s['exp_avg'].zero_() - s['exp_avg_sq'].zero_() - s['step'].fill_(0) - -def _build_ttt_optimizer(lora, args: Hyperparameters): - return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) - -def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: - """Return (start_offset, length) for each document, identified by BOS boundaries. - - If include_next_bos is True, include next document's BOS (to match continuous-stream - eval token count exactly). - """ - bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() - docs = [] - for i in range(len(bos_positions)): - start = int(bos_positions[i]) - end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() - if include_next_bos and i + 1 < len(bos_positions): - end += 1 - assert end - start >= 2 - docs.append((start, end - start)) - return docs - -def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): - """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" - chunk_start = ci * chunk_size - chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size - win_start = max(0, chunk_end - eval_seq_len) - win_len = chunk_end - win_start - chunk_offset = chunk_start - win_start - chunk_len = chunk_end - chunk_start - return win_start, win_len, chunk_offset, chunk_len - -def _accumulate_bpb( - ptl: Tensor, x: Tensor, y: Tensor, - batch_i: int, chunk_offset: int, chunk_len: int, - base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, -): - """Add one doc-chunk's contribution to the running BPB accumulators.""" - lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) - prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] - tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] - tok_bytes = base_bytes_lut[tgt].to(torch.float64) - tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] - loss_sum += lbl.sum() - byte_sum += tok_bytes.sum() - token_count += chunk_len - -def eval_val_ttt_lora( - args: Hyperparameters, - base_model: GPT, - rank: int, - world_size: int, - device: torch.device, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries - files = sorted(glob.glob(args.val_files)) - all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) - docs = _find_docs(all_tokens) - - # Each rank takes a contiguous slice of documents - rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] - chunk_size = args.ttt_chunk_size - eval_seq_len = args.ttt_eval_seq_len - batch_size = args.ttt_batch_size - lora_rank = args.ttt_lora_rank - - rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) - - base_model.eval() - for p in base_model.parameters(): - p.requires_grad_(False) - - lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) - opt = _build_ttt_optimizer(lora, args) - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - byte_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - - for bi in range(0, len(rank_docs), batch_size): - batch = rank_docs[bi:bi + batch_size] - bsz = len(batch) - - if bsz == batch_size: - cur_lora, cur_opt = lora, opt - cur_lora.reset() - _reset_ttt_optimizer(cur_opt) - else: - cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) - cur_opt = _build_ttt_optimizer(cur_lora, args) - - pred_lens = [doc_len - 1 for _, doc_len in batch] - num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] - max_nc = max(num_chunks) - - for ci in range(max_nc): - chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) - context_size, chunk_offset = chunk_stats[1], chunk_stats[2] - - active = [ci < nc for nc in num_chunks] - needs_train = any(ci < nc - 1 for nc in num_chunks) - - x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - doc_info = [] # (chunk_offset, chunk_len) per doc - for b in range(bsz): - if not active[b]: - doc_info.append((0, 0)) - continue - ds, dl = batch[b] - ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) - chunk = all_tokens[ds + ws: ds + ws + wl + 1] - toks = chunk.to(dtype=torch.int64, device=device) - x[b, :wl] = toks[:-1] - y[b, :wl] = toks[1:] - doc_info.append((co, cl)) - - # Forward pass (keep grad graph alive only when we need to train) - if needs_train: - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - else: - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - - # Score: accumulate loss and byte counts for BPB (before training on chunk) - with torch.no_grad(): - for b in range(bsz): - if not active[b]: - continue - co, cl = doc_info[b] - _accumulate_bpb( - ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, loss_sum, byte_sum, token_count) - - # Train: one Adam step on the LoRA params using this chunk's loss - if needs_train: - mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) - per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) - cur_opt.zero_grad() - (per_doc * mask).sum().backward() - cur_opt.step() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - - val_loss = float(loss_sum.item() / token_count.item()) - val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - if isinstance(module, Rotary): - module.inv_freq.data = module.inv_freq.data.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # LoRA test-time training evaluation (the competition score) - torch._dynamo.reset() - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( - args, base_model, rank, world_size, device, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v0.txt b/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v0.txt deleted file mode 100644 index 5445d48403..0000000000 --- a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v0.txt +++ /dev/null @@ -1,1829 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - resume_from = os.environ.get("RESUME_FROM", "") - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # Test-time training (LoRA) hyperparameters. - ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) - ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) - ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) - ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x) + (q_delta if q_delta is not None else 0) - k = self.c_k(x) - v = self.c_v(x) + (v_delta if v_delta is not None else 0) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - qd = lora.q_loras[i] if lora else None - vd = lora.v_loras[i] if lora else None - x = self.blocks[i](x, x0, qd, vd) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - qd = lora.q_loras[bi] if lora else None - vd = lora.v_loras[bi] if lora else None - x = self.blocks[bi](x, x0, qd, vd) - x = self.final_norm(x) - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) - else: - logits = self.lm_head(x) - logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - if lora: - bsz, sl, V = logits.shape - return F.cross_entropy( - logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") - - -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. - -BOS_ID = 1 - -class BatchedLinearLoRA(nn.Module): - """LoRA for a linear layer, with independent weights per batch element. - Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" - def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): - super().__init__() - self.in_features = in_features - self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection - self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection - self.reset() - - def forward(self, x: Tensor) -> Tensor: - return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) - - def reset(self) -> None: - bound = 1.0 / math.sqrt(self.in_features) - with torch.no_grad(): - self.A.uniform_(-bound, bound) # kaiming-uniform - self.B.zero_() - -class BatchedTTTLoRA(nn.Module): - """All LoRA adapters for one batch: LM head and Q/V per block.""" - def __init__(self, bsz: int, model: GPT, rank: int): - super().__init__() - dim = model.tok_emb.embedding_dim - vocab = model.tok_emb.num_embeddings - self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) - self.q_loras = nn.ModuleList() - self.v_loras = nn.ModuleList() - for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) - - def reset(self) -> None: - for m in self.modules(): - if isinstance(m, BatchedLinearLoRA): - m.reset() - -def _reset_ttt_optimizer(opt): - for group in opt.param_groups: - for p in group['params']: - s = opt.state.get(p) - if not s: # Fresh state. - continue - s['exp_avg'].zero_() - s['exp_avg_sq'].zero_() - s['step'].fill_(0) - -def _build_ttt_optimizer(lora, args: Hyperparameters): - return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) - -def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: - """Return (start_offset, length) for each document, identified by BOS boundaries. - - If include_next_bos is True, include next document's BOS (to match continuous-stream - eval token count exactly). - """ - bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() - docs = [] - for i in range(len(bos_positions)): - start = int(bos_positions[i]) - end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() - if include_next_bos and i + 1 < len(bos_positions): - end += 1 - assert end - start >= 2 - docs.append((start, end - start)) - return docs - -def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): - """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" - chunk_start = ci * chunk_size - chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size - win_start = max(0, chunk_end - eval_seq_len) - win_len = chunk_end - win_start - chunk_offset = chunk_start - win_start - chunk_len = chunk_end - chunk_start - return win_start, win_len, chunk_offset, chunk_len - -def _accumulate_bpb( - ptl: Tensor, x: Tensor, y: Tensor, - batch_i: int, chunk_offset: int, chunk_len: int, - base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, -): - """Add one doc-chunk's contribution to the running BPB accumulators.""" - lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) - prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] - tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] - tok_bytes = base_bytes_lut[tgt].to(torch.float64) - tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] - loss_sum += lbl.sum() - byte_sum += tok_bytes.sum() - token_count += chunk_len - -def eval_val_ttt_lora( - args: Hyperparameters, - base_model: GPT, - rank: int, - world_size: int, - device: torch.device, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries - files = sorted(glob.glob(args.val_files)) - all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) - docs = _find_docs(all_tokens) - - # Each rank takes a contiguous slice of documents - rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] - chunk_size = args.ttt_chunk_size - eval_seq_len = args.ttt_eval_seq_len - batch_size = args.ttt_batch_size - lora_rank = args.ttt_lora_rank - - rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) - - base_model.eval() - for p in base_model.parameters(): - p.requires_grad_(False) - - lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) - opt = _build_ttt_optimizer(lora, args) - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - byte_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - - for bi in range(0, len(rank_docs), batch_size): - batch = rank_docs[bi:bi + batch_size] - bsz = len(batch) - - if bsz == batch_size: - cur_lora, cur_opt = lora, opt - cur_lora.reset() - _reset_ttt_optimizer(cur_opt) - else: - cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) - cur_opt = _build_ttt_optimizer(cur_lora, args) - - pred_lens = [doc_len - 1 for _, doc_len in batch] - num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] - max_nc = max(num_chunks) - - for ci in range(max_nc): - chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) - context_size, chunk_offset = chunk_stats[1], chunk_stats[2] - - active = [ci < nc for nc in num_chunks] - needs_train = any(ci < nc - 1 for nc in num_chunks) - - x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - doc_info = [] # (chunk_offset, chunk_len) per doc - for b in range(bsz): - if not active[b]: - doc_info.append((0, 0)) - continue - ds, dl = batch[b] - ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) - chunk = all_tokens[ds + ws: ds + ws + wl + 1] - toks = chunk.to(dtype=torch.int64, device=device) - x[b, :wl] = toks[:-1] - y[b, :wl] = toks[1:] - doc_info.append((co, cl)) - - # Forward pass (keep grad graph alive only when we need to train) - if needs_train: - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - else: - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - - # Score: accumulate loss and byte counts for BPB (before training on chunk) - with torch.no_grad(): - for b in range(bsz): - if not active[b]: - continue - co, cl = doc_info[b] - _accumulate_bpb( - ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, loss_sum, byte_sum, token_count) - - # Train: one Adam step on the LoRA params using this chunk's loss - if needs_train: - mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) - per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) - cur_opt.zero_grad() - (per_doc * mask).sum().backward() - cur_opt.step() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - - val_loss = float(loss_sum.item() / token_count.item()) - val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - if isinstance(module, Rotary): - module.inv_freq.data = module.inv_freq.data.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # LoRA test-time training evaluation (the competition score) - torch._dynamo.reset() - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( - args, base_model, rank, world_size, device, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.13 (main, Mar 10 2026, 18:17:25) [Clang 21.1.4 ] -Running PyTorch 2.10.0+cu128 -Thu Mar 19 10:17:46 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | -| N/A 35C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | -| N/A 31C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 35C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | -| N/A 33C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | -| N/A 36C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | -| N/A 32C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | -| N/A 33C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | -| N/A 32C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 34383 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 1 N/A N/A 34384 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 2 N/A N/A 34385 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 3 N/A N/A 34386 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 4 N/A N/A 34387 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 5 N/A N/A 34388 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 6 N/A N/A 34389 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 7 N/A N/A 34390 C ...ai-codegolf/.venv/bin/python3 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:25 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9370 train_time:24ms step_avg:24.05ms -step:2/20000 train_loss:16.8366 train_time:67ms step_avg:33.26ms -step:3/20000 train_loss:8.7610 train_time:109ms step_avg:36.40ms -step:4/20000 train_loss:6.6384 train_time:152ms step_avg:38.07ms -step:5/20000 train_loss:6.6119 train_time:196ms step_avg:39.12ms -step:6/20000 train_loss:7.4221 train_time:239ms step_avg:39.82ms -step:7/20000 train_loss:6.3502 train_time:284ms step_avg:40.50ms -step:8/20000 train_loss:6.1581 train_time:326ms step_avg:40.77ms -step:9/20000 train_loss:6.0679 train_time:370ms step_avg:41.08ms -step:10/20000 train_loss:5.9747 train_time:413ms step_avg:41.32ms -step:50/20000 train_loss:4.0981 train_time:2146ms step_avg:42.93ms -step:100/20000 train_loss:3.4127 train_time:4311ms step_avg:43.11ms -step:150/20000 train_loss:3.0586 train_time:6474ms step_avg:43.16ms -step:200/20000 train_loss:2.8472 train_time:8711ms step_avg:43.55ms -step:200/20000 val_loss:2.8382 val_bpb:1.6810 train_time:8737ms step_avg:43.69ms -step:250/20000 train_loss:2.7512 train_time:10877ms step_avg:43.51ms -step:300/20000 train_loss:2.4788 train_time:13043ms step_avg:43.48ms -step:350/20000 train_loss:2.6744 train_time:15220ms step_avg:43.48ms -step:400/20000 train_loss:2.3620 train_time:17456ms step_avg:43.64ms -step:400/20000 val_loss:2.5653 val_bpb:1.5193 train_time:17482ms step_avg:43.71ms -step:450/20000 train_loss:2.5101 train_time:19618ms step_avg:43.59ms -step:500/20000 train_loss:2.5009 train_time:21779ms step_avg:43.56ms -step:550/20000 train_loss:2.4061 train_time:23941ms step_avg:43.53ms -step:600/20000 train_loss:2.5474 train_time:26171ms step_avg:43.62ms -step:600/20000 val_loss:2.4532 val_bpb:1.4529 train_time:26197ms step_avg:43.66ms -step:650/20000 train_loss:2.3861 train_time:28335ms step_avg:43.59ms -step:700/20000 train_loss:2.4381 train_time:30496ms step_avg:43.57ms -step:750/20000 train_loss:2.2819 train_time:32658ms step_avg:43.54ms -step:800/20000 train_loss:2.2909 train_time:34902ms step_avg:43.63ms -step:800/20000 val_loss:2.3822 val_bpb:1.4108 train_time:34928ms step_avg:43.66ms -step:850/20000 train_loss:2.7163 train_time:37064ms step_avg:43.60ms -step:900/20000 train_loss:2.3387 train_time:39226ms step_avg:43.58ms -step:950/20000 train_loss:2.4032 train_time:41385ms step_avg:43.56ms -step:1000/20000 train_loss:2.3745 train_time:43619ms step_avg:43.62ms -step:1000/20000 val_loss:2.3353 val_bpb:1.3831 train_time:43645ms step_avg:43.65ms -step:1050/20000 train_loss:2.4839 train_time:45781ms step_avg:43.60ms -step:1100/20000 train_loss:2.2656 train_time:47941ms step_avg:43.58ms -step:1150/20000 train_loss:2.2568 train_time:50179ms step_avg:43.63ms -step:1200/20000 train_loss:2.3842 train_time:52339ms step_avg:43.62ms -step:1200/20000 val_loss:2.3033 val_bpb:1.3641 train_time:52365ms step_avg:43.64ms -step:1250/20000 train_loss:2.2091 train_time:54500ms step_avg:43.60ms -step:1300/20000 train_loss:2.3569 train_time:56660ms step_avg:43.58ms -step:1350/20000 train_loss:2.2731 train_time:58892ms step_avg:43.62ms -step:1400/20000 train_loss:2.4328 train_time:61053ms step_avg:43.61ms -step:1400/20000 val_loss:2.2820 val_bpb:1.3516 train_time:61079ms step_avg:43.63ms -step:1450/20000 train_loss:2.2351 train_time:63214ms step_avg:43.60ms -step:1500/20000 train_loss:2.2253 train_time:65376ms step_avg:43.58ms -step:1550/20000 train_loss:2.1577 train_time:67620ms step_avg:43.63ms -step:1600/20000 train_loss:2.1028 train_time:69781ms step_avg:43.61ms -step:1600/20000 val_loss:2.2677 val_bpb:1.3431 train_time:69807ms step_avg:43.63ms -step:1650/20000 train_loss:2.2339 train_time:71944ms step_avg:43.60ms -step:1700/20000 train_loss:2.1725 train_time:74104ms step_avg:43.59ms -step:1750/20000 train_loss:2.2485 train_time:76333ms step_avg:43.62ms -step:1800/20000 train_loss:2.1978 train_time:78493ms step_avg:43.61ms -step:1800/20000 val_loss:2.2520 val_bpb:1.3337 train_time:78519ms step_avg:43.62ms -step:1850/20000 train_loss:2.3073 train_time:80656ms step_avg:43.60ms -step:1900/20000 train_loss:2.1927 train_time:82817ms step_avg:43.59ms -step:1950/20000 train_loss:2.2128 train_time:85048ms step_avg:43.61ms -step:2000/20000 train_loss:2.2480 train_time:87209ms step_avg:43.60ms -step:2000/20000 val_loss:2.2362 val_bpb:1.3244 train_time:87235ms step_avg:43.62ms -step:2050/20000 train_loss:2.2551 train_time:89371ms step_avg:43.60ms -step:2100/20000 train_loss:2.2680 train_time:91606ms step_avg:43.62ms -step:2150/20000 train_loss:2.1907 train_time:93765ms step_avg:43.61ms -step:2200/20000 train_loss:2.0748 train_time:95924ms step_avg:43.60ms -step:2200/20000 val_loss:2.2283 val_bpb:1.3197 train_time:95951ms step_avg:43.61ms -step:2250/20000 train_loss:2.1622 train_time:98088ms step_avg:43.59ms -step:2300/20000 train_loss:2.3789 train_time:100334ms step_avg:43.62ms -step:2350/20000 train_loss:2.1982 train_time:102494ms step_avg:43.61ms -step:2400/20000 train_loss:2.2039 train_time:104656ms step_avg:43.61ms -step:2400/20000 val_loss:2.2179 val_bpb:1.3136 train_time:104682ms step_avg:43.62ms -step:2450/20000 train_loss:2.2033 train_time:106818ms step_avg:43.60ms -step:2500/20000 train_loss:2.1229 train_time:109053ms step_avg:43.62ms -step:2550/20000 train_loss:2.1382 train_time:111215ms step_avg:43.61ms -step:2600/20000 train_loss:2.4142 train_time:113375ms step_avg:43.61ms -step:2600/20000 val_loss:2.2189 val_bpb:1.3141 train_time:113401ms step_avg:43.62ms -step:2650/20000 train_loss:2.2441 train_time:115539ms step_avg:43.60ms -step:2700/20000 train_loss:2.1538 train_time:117791ms step_avg:43.63ms -step:2750/20000 train_loss:2.3638 train_time:119952ms step_avg:43.62ms -step:2800/20000 train_loss:2.2326 train_time:122113ms step_avg:43.61ms -step:2800/20000 val_loss:2.2024 val_bpb:1.3044 train_time:122139ms step_avg:43.62ms -step:2850/20000 train_loss:2.1876 train_time:124275ms step_avg:43.61ms -step:2900/20000 train_loss:2.1816 train_time:126510ms step_avg:43.62ms -step:2950/20000 train_loss:2.2381 train_time:128671ms step_avg:43.62ms -step:3000/20000 train_loss:2.2290 train_time:130832ms step_avg:43.61ms -step:3000/20000 val_loss:2.1954 val_bpb:1.3002 train_time:130858ms step_avg:43.62ms -step:3050/20000 train_loss:2.1750 train_time:132995ms step_avg:43.61ms -step:3100/20000 train_loss:2.2081 train_time:135227ms step_avg:43.62ms -step:3150/20000 train_loss:2.1601 train_time:137390ms step_avg:43.62ms -step:3200/20000 train_loss:2.1897 train_time:139551ms step_avg:43.61ms -step:3200/20000 val_loss:2.1901 val_bpb:1.2971 train_time:139577ms step_avg:43.62ms -step:3250/20000 train_loss:2.0892 train_time:141783ms step_avg:43.63ms -step:3300/20000 train_loss:2.2393 train_time:143945ms step_avg:43.62ms -step:3350/20000 train_loss:2.0947 train_time:146104ms step_avg:43.61ms -step:3400/20000 train_loss:2.1618 train_time:148263ms step_avg:43.61ms -step:3400/20000 val_loss:2.1873 val_bpb:1.2955 train_time:148289ms step_avg:43.61ms -step:3450/20000 train_loss:2.1101 train_time:150509ms step_avg:43.63ms -step:3500/20000 train_loss:2.2549 train_time:152670ms step_avg:43.62ms -step:3550/20000 train_loss:2.3879 train_time:154829ms step_avg:43.61ms -step:3600/20000 train_loss:2.1151 train_time:156990ms step_avg:43.61ms -step:3600/20000 val_loss:2.1790 val_bpb:1.2905 train_time:157016ms step_avg:43.62ms -step:3650/20000 train_loss:2.2190 train_time:159223ms step_avg:43.62ms -step:3700/20000 train_loss:2.1509 train_time:161384ms step_avg:43.62ms -step:3750/20000 train_loss:2.1447 train_time:163543ms step_avg:43.61ms -step:3800/20000 train_loss:2.2250 train_time:165702ms step_avg:43.61ms -step:3800/20000 val_loss:2.1761 val_bpb:1.2888 train_time:165728ms step_avg:43.61ms -step:3850/20000 train_loss:2.1758 train_time:167942ms step_avg:43.62ms -step:3900/20000 train_loss:1.9878 train_time:170102ms step_avg:43.62ms -step:3950/20000 train_loss:2.1273 train_time:172262ms step_avg:43.61ms -step:4000/20000 train_loss:2.1646 train_time:174422ms step_avg:43.61ms -step:4000/20000 val_loss:2.1713 val_bpb:1.2860 train_time:174448ms step_avg:43.61ms -step:4050/20000 train_loss:2.1020 train_time:176658ms step_avg:43.62ms -step:4100/20000 train_loss:2.1886 train_time:178818ms step_avg:43.61ms -step:4150/20000 train_loss:2.3235 train_time:180979ms step_avg:43.61ms -step:4200/20000 train_loss:2.1733 train_time:183219ms step_avg:43.62ms -step:4200/20000 val_loss:2.1669 val_bpb:1.2834 train_time:183245ms step_avg:43.63ms -step:4250/20000 train_loss:2.1272 train_time:185379ms step_avg:43.62ms -step:4300/20000 train_loss:2.0240 train_time:187539ms step_avg:43.61ms -step:4350/20000 train_loss:2.2121 train_time:189700ms step_avg:43.61ms -step:4400/20000 train_loss:2.1162 train_time:191931ms step_avg:43.62ms -step:4400/20000 val_loss:2.1671 val_bpb:1.2835 train_time:191957ms step_avg:43.63ms -step:4450/20000 train_loss:2.0625 train_time:194091ms step_avg:43.62ms -step:4500/20000 train_loss:2.2588 train_time:196251ms step_avg:43.61ms -step:4550/20000 train_loss:2.0573 train_time:198409ms step_avg:43.61ms -step:4600/20000 train_loss:1.9725 train_time:200640ms step_avg:43.62ms -step:4600/20000 val_loss:2.1626 val_bpb:1.2808 train_time:200665ms step_avg:43.62ms -step:4650/20000 train_loss:2.0772 train_time:202798ms step_avg:43.61ms -step:4700/20000 train_loss:2.2693 train_time:204958ms step_avg:43.61ms -step:4750/20000 train_loss:1.9775 train_time:207118ms step_avg:43.60ms -step:4800/20000 train_loss:2.1280 train_time:209346ms step_avg:43.61ms -step:4800/20000 val_loss:2.1577 val_bpb:1.2779 train_time:209372ms step_avg:43.62ms -step:4850/20000 train_loss:2.2137 train_time:211507ms step_avg:43.61ms -step:4900/20000 train_loss:2.4117 train_time:213667ms step_avg:43.61ms -step:4950/20000 train_loss:2.1671 train_time:215827ms step_avg:43.60ms -step:5000/20000 train_loss:2.1401 train_time:218057ms step_avg:43.61ms -step:5000/20000 val_loss:2.1541 val_bpb:1.2758 train_time:218083ms step_avg:43.62ms -step:5050/20000 train_loss:2.0841 train_time:220219ms step_avg:43.61ms -step:5100/20000 train_loss:2.0908 train_time:222380ms step_avg:43.60ms -step:5150/20000 train_loss:2.1483 train_time:224605ms step_avg:43.61ms -step:5200/20000 train_loss:2.2376 train_time:226762ms step_avg:43.61ms -step:5200/20000 val_loss:2.1515 val_bpb:1.2743 train_time:226788ms step_avg:43.61ms -step:5250/20000 train_loss:2.0875 train_time:228922ms step_avg:43.60ms -step:5300/20000 train_loss:2.2171 train_time:231081ms step_avg:43.60ms -step:5350/20000 train_loss:2.5601 train_time:233311ms step_avg:43.61ms -step:5400/20000 train_loss:2.2814 train_time:235470ms step_avg:43.61ms -step:5400/20000 val_loss:2.1498 val_bpb:1.2732 train_time:235496ms step_avg:43.61ms -step:5450/20000 train_loss:2.1625 train_time:237632ms step_avg:43.60ms -step:5500/20000 train_loss:2.1717 train_time:239792ms step_avg:43.60ms -step:5550/20000 train_loss:2.1832 train_time:242022ms step_avg:43.61ms -step:5600/20000 train_loss:2.1615 train_time:244183ms step_avg:43.60ms -step:5600/20000 val_loss:2.1462 val_bpb:1.2711 train_time:244209ms step_avg:43.61ms -step:5650/20000 train_loss:2.1320 train_time:246344ms step_avg:43.60ms -step:5700/20000 train_loss:2.2610 train_time:248505ms step_avg:43.60ms -step:5750/20000 train_loss:2.0951 train_time:250740ms step_avg:43.61ms -step:5800/20000 train_loss:2.2384 train_time:252899ms step_avg:43.60ms -step:5800/20000 val_loss:2.1447 val_bpb:1.2702 train_time:252925ms step_avg:43.61ms -step:5850/20000 train_loss:2.3104 train_time:255061ms step_avg:43.60ms -step:5900/20000 train_loss:2.1449 train_time:257222ms step_avg:43.60ms -step:5950/20000 train_loss:2.0354 train_time:259454ms step_avg:43.61ms -step:6000/20000 train_loss:2.2121 train_time:261615ms step_avg:43.60ms -step:6000/20000 val_loss:2.1412 val_bpb:1.2682 train_time:261641ms step_avg:43.61ms -step:6050/20000 train_loss:2.0146 train_time:263777ms step_avg:43.60ms -step:6100/20000 train_loss:2.2950 train_time:265936ms step_avg:43.60ms -step:6150/20000 train_loss:1.9642 train_time:268174ms step_avg:43.61ms -step:6200/20000 train_loss:2.1078 train_time:270334ms step_avg:43.60ms -step:6200/20000 val_loss:2.1397 val_bpb:1.2673 train_time:270360ms step_avg:43.61ms -step:6250/20000 train_loss:2.1381 train_time:272500ms step_avg:43.60ms -step:6300/20000 train_loss:1.9401 train_time:274730ms step_avg:43.61ms -step:6350/20000 train_loss:2.1690 train_time:276891ms step_avg:43.60ms -step:6400/20000 train_loss:2.1152 train_time:279052ms step_avg:43.60ms -step:6400/20000 val_loss:2.1392 val_bpb:1.2670 train_time:279078ms step_avg:43.61ms -step:6450/20000 train_loss:2.1248 train_time:281214ms step_avg:43.60ms -step:6500/20000 train_loss:2.1134 train_time:283449ms step_avg:43.61ms -step:6550/20000 train_loss:2.0951 train_time:285610ms step_avg:43.60ms -step:6600/20000 train_loss:2.0080 train_time:287770ms step_avg:43.60ms -step:6600/20000 val_loss:2.1374 val_bpb:1.2659 train_time:287796ms step_avg:43.61ms -step:6650/20000 train_loss:2.2209 train_time:289932ms step_avg:43.60ms -step:6700/20000 train_loss:2.1401 train_time:292151ms step_avg:43.60ms -step:6750/20000 train_loss:2.1543 train_time:294311ms step_avg:43.60ms -step:6800/20000 train_loss:1.9578 train_time:296471ms step_avg:43.60ms -step:6800/20000 val_loss:2.1372 val_bpb:1.2658 train_time:296497ms step_avg:43.60ms -step:6850/20000 train_loss:2.0710 train_time:298632ms step_avg:43.60ms -step:6900/20000 train_loss:2.1392 train_time:300868ms step_avg:43.60ms -step:6950/20000 train_loss:2.0298 train_time:303028ms step_avg:43.60ms -step:7000/20000 train_loss:2.1949 train_time:305189ms step_avg:43.60ms -step:7000/20000 val_loss:2.1321 val_bpb:1.2627 train_time:305215ms step_avg:43.60ms -step:7050/20000 train_loss:2.0632 train_time:307351ms step_avg:43.60ms -step:7100/20000 train_loss:2.2315 train_time:309582ms step_avg:43.60ms -step:7150/20000 train_loss:2.1151 train_time:311742ms step_avg:43.60ms -step:7200/20000 train_loss:2.0344 train_time:313902ms step_avg:43.60ms -step:7200/20000 val_loss:2.1321 val_bpb:1.2628 train_time:313928ms step_avg:43.60ms -step:7250/20000 train_loss:2.0835 train_time:316128ms step_avg:43.60ms -step:7300/20000 train_loss:2.1824 train_time:318288ms step_avg:43.60ms -step:7350/20000 train_loss:2.2120 train_time:320449ms step_avg:43.60ms -step:7400/20000 train_loss:2.1412 train_time:322610ms step_avg:43.60ms -step:7400/20000 val_loss:2.1287 val_bpb:1.2607 train_time:322636ms step_avg:43.60ms -step:7450/20000 train_loss:2.1712 train_time:324837ms step_avg:43.60ms -step:7500/20000 train_loss:2.1289 train_time:326997ms step_avg:43.60ms -step:7550/20000 train_loss:2.1397 train_time:329158ms step_avg:43.60ms -step:7600/20000 train_loss:2.1551 train_time:331318ms step_avg:43.59ms -step:7600/20000 val_loss:2.1276 val_bpb:1.2601 train_time:331344ms step_avg:43.60ms -step:7650/20000 train_loss:2.1268 train_time:333556ms step_avg:43.60ms -step:7700/20000 train_loss:2.1523 train_time:335716ms step_avg:43.60ms -step:7750/20000 train_loss:2.2340 train_time:337875ms step_avg:43.60ms -step:7800/20000 train_loss:2.0813 train_time:340036ms step_avg:43.59ms -step:7800/20000 val_loss:2.1279 val_bpb:1.2602 train_time:340062ms step_avg:43.60ms -step:7850/20000 train_loss:2.1322 train_time:342261ms step_avg:43.60ms -step:7900/20000 train_loss:2.0901 train_time:344421ms step_avg:43.60ms -step:7950/20000 train_loss:2.1353 train_time:346581ms step_avg:43.60ms -step:8000/20000 train_loss:2.1558 train_time:348741ms step_avg:43.59ms -step:8000/20000 val_loss:2.1253 val_bpb:1.2587 train_time:348767ms step_avg:43.60ms -step:8050/20000 train_loss:2.1541 train_time:350979ms step_avg:43.60ms -step:8100/20000 train_loss:2.1860 train_time:353139ms step_avg:43.60ms -step:8150/20000 train_loss:2.0543 train_time:355298ms step_avg:43.59ms -step:8200/20000 train_loss:2.0280 train_time:357459ms step_avg:43.59ms -step:8200/20000 val_loss:2.1262 val_bpb:1.2593 train_time:357485ms step_avg:43.60ms -step:8250/20000 train_loss:2.0998 train_time:359690ms step_avg:43.60ms -step:8300/20000 train_loss:2.0534 train_time:361851ms step_avg:43.60ms -step:8350/20000 train_loss:2.1787 train_time:364010ms step_avg:43.59ms -step:8400/20000 train_loss:2.2065 train_time:366260ms step_avg:43.60ms -step:8400/20000 val_loss:2.1219 val_bpb:1.2567 train_time:366285ms step_avg:43.61ms -step:8450/20000 train_loss:2.1879 train_time:368419ms step_avg:43.60ms -step:8500/20000 train_loss:2.1337 train_time:370579ms step_avg:43.60ms -step:8550/20000 train_loss:2.1921 train_time:372739ms step_avg:43.60ms -step:8600/20000 train_loss:2.1191 train_time:374971ms step_avg:43.60ms -step:8600/20000 val_loss:2.1187 val_bpb:1.2548 train_time:374997ms step_avg:43.60ms -step:8650/20000 train_loss:2.0102 train_time:377133ms step_avg:43.60ms -step:8700/20000 train_loss:2.0759 train_time:379294ms step_avg:43.60ms -step:8750/20000 train_loss:2.1173 train_time:381453ms step_avg:43.59ms -step:8800/20000 train_loss:2.0618 train_time:383685ms step_avg:43.60ms -step:8800/20000 val_loss:2.1176 val_bpb:1.2541 train_time:383710ms step_avg:43.60ms -step:8850/20000 train_loss:2.0558 train_time:385844ms step_avg:43.60ms -step:8900/20000 train_loss:2.1128 train_time:388005ms step_avg:43.60ms -step:8950/20000 train_loss:2.1617 train_time:390164ms step_avg:43.59ms -step:9000/20000 train_loss:2.3229 train_time:392396ms step_avg:43.60ms -step:9000/20000 val_loss:2.1165 val_bpb:1.2535 train_time:392422ms step_avg:43.60ms -step:9050/20000 train_loss:2.1856 train_time:394557ms step_avg:43.60ms -step:9100/20000 train_loss:2.0103 train_time:396717ms step_avg:43.60ms -step:9150/20000 train_loss:2.2225 train_time:398876ms step_avg:43.59ms -step:9200/20000 train_loss:2.2887 train_time:401107ms step_avg:43.60ms -step:9200/20000 val_loss:2.1155 val_bpb:1.2529 train_time:401133ms step_avg:43.60ms -step:9250/20000 train_loss:2.1497 train_time:403267ms step_avg:43.60ms -step:9300/20000 train_loss:2.3689 train_time:405428ms step_avg:43.59ms -step:9350/20000 train_loss:2.1751 train_time:407655ms step_avg:43.60ms -step:9400/20000 train_loss:1.9253 train_time:409813ms step_avg:43.60ms -step:9400/20000 val_loss:2.1164 val_bpb:1.2535 train_time:409839ms step_avg:43.60ms -step:9450/20000 train_loss:2.0587 train_time:411973ms step_avg:43.60ms -step:9500/20000 train_loss:2.1604 train_time:414132ms step_avg:43.59ms -step:9550/20000 train_loss:2.2149 train_time:416364ms step_avg:43.60ms -step:9600/20000 train_loss:2.0156 train_time:418523ms step_avg:43.60ms -step:9600/20000 val_loss:2.1152 val_bpb:1.2527 train_time:418549ms step_avg:43.60ms -step:9650/20000 train_loss:2.0649 train_time:420684ms step_avg:43.59ms -step:9700/20000 train_loss:2.1622 train_time:422842ms step_avg:43.59ms -step:9750/20000 train_loss:2.1573 train_time:425066ms step_avg:43.60ms -step:9800/20000 train_loss:2.0715 train_time:427227ms step_avg:43.59ms -step:9800/20000 val_loss:2.1113 val_bpb:1.2504 train_time:427253ms step_avg:43.60ms -step:9850/20000 train_loss:2.1412 train_time:429389ms step_avg:43.59ms -step:9900/20000 train_loss:2.0139 train_time:431550ms step_avg:43.59ms -step:9950/20000 train_loss:2.1478 train_time:433776ms step_avg:43.60ms -step:10000/20000 train_loss:2.0268 train_time:435937ms step_avg:43.59ms -step:10000/20000 val_loss:2.1136 val_bpb:1.2518 train_time:435963ms step_avg:43.60ms -step:10050/20000 train_loss:2.0426 train_time:438097ms step_avg:43.59ms -step:10100/20000 train_loss:2.1084 train_time:440257ms step_avg:43.59ms -step:10150/20000 train_loss:2.1516 train_time:442485ms step_avg:43.59ms -step:10200/20000 train_loss:2.1381 train_time:444644ms step_avg:43.59ms -step:10200/20000 val_loss:2.1113 val_bpb:1.2504 train_time:444670ms step_avg:43.60ms -step:10250/20000 train_loss:2.0997 train_time:446805ms step_avg:43.59ms -step:10300/20000 train_loss:2.0006 train_time:449044ms step_avg:43.60ms -step:10350/20000 train_loss:1.9698 train_time:451202ms step_avg:43.59ms -step:10400/20000 train_loss:2.1025 train_time:453362ms step_avg:43.59ms -step:10400/20000 val_loss:2.1098 val_bpb:1.2496 train_time:453388ms step_avg:43.60ms -step:10450/20000 train_loss:1.9823 train_time:455524ms step_avg:43.59ms -step:10500/20000 train_loss:2.0548 train_time:457751ms step_avg:43.60ms -step:10550/20000 train_loss:2.1259 train_time:459909ms step_avg:43.59ms -step:10600/20000 train_loss:2.0770 train_time:462070ms step_avg:43.59ms -step:10600/20000 val_loss:2.1089 val_bpb:1.2490 train_time:462096ms step_avg:43.59ms -step:10650/20000 train_loss:2.1330 train_time:464231ms step_avg:43.59ms -step:10700/20000 train_loss:2.2574 train_time:466480ms step_avg:43.60ms -step:10750/20000 train_loss:2.0062 train_time:468640ms step_avg:43.59ms -step:10800/20000 train_loss:2.1240 train_time:470799ms step_avg:43.59ms -step:10800/20000 val_loss:2.1082 val_bpb:1.2486 train_time:470825ms step_avg:43.59ms -step:10850/20000 train_loss:2.2353 train_time:472960ms step_avg:43.59ms -step:10900/20000 train_loss:2.1910 train_time:475190ms step_avg:43.60ms -step:10950/20000 train_loss:2.0309 train_time:477349ms step_avg:43.59ms -step:11000/20000 train_loss:2.1029 train_time:479511ms step_avg:43.59ms -step:11000/20000 val_loss:2.1086 val_bpb:1.2488 train_time:479537ms step_avg:43.59ms -step:11050/20000 train_loss:2.1257 train_time:481674ms step_avg:43.59ms -step:11100/20000 train_loss:2.1027 train_time:483913ms step_avg:43.60ms -step:11150/20000 train_loss:2.0705 train_time:486074ms step_avg:43.59ms -step:11200/20000 train_loss:2.1360 train_time:488234ms step_avg:43.59ms -step:11200/20000 val_loss:2.1064 val_bpb:1.2475 train_time:488260ms step_avg:43.59ms -step:11250/20000 train_loss:2.1185 train_time:490395ms step_avg:43.59ms -step:11300/20000 train_loss:2.0835 train_time:492623ms step_avg:43.59ms -step:11350/20000 train_loss:2.0724 train_time:494783ms step_avg:43.59ms -step:11400/20000 train_loss:2.2232 train_time:496942ms step_avg:43.59ms -step:11400/20000 val_loss:2.1052 val_bpb:1.2468 train_time:496968ms step_avg:43.59ms -step:11450/20000 train_loss:2.1210 train_time:499185ms step_avg:43.60ms -step:11500/20000 train_loss:2.1043 train_time:501344ms step_avg:43.60ms -step:11550/20000 train_loss:2.1288 train_time:503506ms step_avg:43.59ms -step:11600/20000 train_loss:2.1453 train_time:505668ms step_avg:43.59ms -step:11600/20000 val_loss:2.1048 val_bpb:1.2466 train_time:505694ms step_avg:43.59ms -step:11650/20000 train_loss:2.1545 train_time:507895ms step_avg:43.60ms -step:11700/20000 train_loss:2.0164 train_time:510057ms step_avg:43.59ms -step:11750/20000 train_loss:2.0665 train_time:512218ms step_avg:43.59ms -step:11800/20000 train_loss:2.0242 train_time:514379ms step_avg:43.59ms -step:11800/20000 val_loss:2.1053 val_bpb:1.2469 train_time:514405ms step_avg:43.59ms -step:11850/20000 train_loss:2.1169 train_time:516605ms step_avg:43.60ms -step:11900/20000 train_loss:2.2988 train_time:518766ms step_avg:43.59ms -step:11950/20000 train_loss:2.1103 train_time:520926ms step_avg:43.59ms -step:12000/20000 train_loss:2.0714 train_time:523087ms step_avg:43.59ms -step:12000/20000 val_loss:2.1032 val_bpb:1.2457 train_time:523113ms step_avg:43.59ms -step:12050/20000 train_loss:2.1956 train_time:525311ms step_avg:43.59ms -step:12100/20000 train_loss:2.0917 train_time:527470ms step_avg:43.59ms -step:12150/20000 train_loss:2.1263 train_time:529631ms step_avg:43.59ms -step:12200/20000 train_loss:2.1623 train_time:531790ms step_avg:43.59ms -step:12200/20000 val_loss:2.1023 val_bpb:1.2451 train_time:531816ms step_avg:43.59ms -step:12250/20000 train_loss:2.0889 train_time:534028ms step_avg:43.59ms -step:12300/20000 train_loss:1.9699 train_time:536188ms step_avg:43.59ms -step:12350/20000 train_loss:2.0895 train_time:538348ms step_avg:43.59ms -step:12400/20000 train_loss:2.1525 train_time:540575ms step_avg:43.59ms -step:12400/20000 val_loss:2.1014 val_bpb:1.2445 train_time:540601ms step_avg:43.60ms -step:12450/20000 train_loss:2.1547 train_time:542736ms step_avg:43.59ms -step:12500/20000 train_loss:2.1525 train_time:544895ms step_avg:43.59ms -step:12550/20000 train_loss:2.1029 train_time:547056ms step_avg:43.59ms -step:12600/20000 train_loss:2.3650 train_time:549281ms step_avg:43.59ms -step:12600/20000 val_loss:2.1001 val_bpb:1.2438 train_time:549307ms step_avg:43.60ms -step:12650/20000 train_loss:1.9888 train_time:551444ms step_avg:43.59ms -step:12700/20000 train_loss:2.0505 train_time:553604ms step_avg:43.59ms -step:12750/20000 train_loss:2.0902 train_time:555765ms step_avg:43.59ms -step:12800/20000 train_loss:2.1704 train_time:557995ms step_avg:43.59ms -step:12800/20000 val_loss:2.0948 val_bpb:1.2407 train_time:558021ms step_avg:43.60ms -step:12850/20000 train_loss:1.9939 train_time:560156ms step_avg:43.59ms -step:12900/20000 train_loss:2.4728 train_time:562317ms step_avg:43.59ms -step:12950/20000 train_loss:2.0365 train_time:564477ms step_avg:43.59ms -step:13000/20000 train_loss:2.1251 train_time:566719ms step_avg:43.59ms -step:13000/20000 val_loss:2.0865 val_bpb:1.2357 train_time:566745ms step_avg:43.60ms -step:13050/20000 train_loss:2.0270 train_time:568879ms step_avg:43.59ms -step:13100/20000 train_loss:1.9699 train_time:571041ms step_avg:43.59ms -step:13150/20000 train_loss:2.1015 train_time:573200ms step_avg:43.59ms -step:13200/20000 train_loss:2.1239 train_time:575429ms step_avg:43.59ms -step:13200/20000 val_loss:2.0785 val_bpb:1.2310 train_time:575455ms step_avg:43.60ms -step:13250/20000 train_loss:2.1207 train_time:577589ms step_avg:43.59ms -step:13300/20000 train_loss:2.1447 train_time:579749ms step_avg:43.59ms -step:13350/20000 train_loss:2.1591 train_time:581910ms step_avg:43.59ms -step:13400/20000 train_loss:2.1473 train_time:584147ms step_avg:43.59ms -step:13400/20000 val_loss:2.0717 val_bpb:1.2270 train_time:584173ms step_avg:43.59ms -step:13450/20000 train_loss:2.1372 train_time:586307ms step_avg:43.59ms -step:13500/20000 train_loss:1.9835 train_time:588468ms step_avg:43.59ms -step:13550/20000 train_loss:2.0803 train_time:590696ms step_avg:43.59ms -step:13600/20000 train_loss:2.0473 train_time:592857ms step_avg:43.59ms -step:13600/20000 val_loss:2.0642 val_bpb:1.2225 train_time:592883ms step_avg:43.59ms -step:13650/20000 train_loss:2.0700 train_time:595018ms step_avg:43.59ms -step:13700/20000 train_loss:2.1010 train_time:597179ms step_avg:43.59ms -step:13750/20000 train_loss:2.0827 train_time:599431ms step_avg:43.59ms -step:13764/20000 val_loss:2.0598 val_bpb:1.2199 train_time:600061ms step_avg:43.60ms -stopping_early: wallclock_cap train_time:600061ms step:13764/20000 -peak memory allocated: 10184 MiB reserved: 10624 MiB -Serialized model: 67224983 bytes -Code size: 58380 bytes -Total submission size: 67283363 bytes -Serialized model int8+zlib: 15814468 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15872848 bytes -final_int8_zlib_roundtrip val_loss:2.0741 val_bpb:1.2284 eval_time:5599ms -final_int8_zlib_roundtrip_exact val_loss:2.07405347 val_bpb:1.22837129 -final_ttt_lora val_loss:2.0153 val_bpb:1.1935 eval_time:72361ms diff --git a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v1.txt b/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v1.txt deleted file mode 100644 index 523f81ae26..0000000000 --- a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v1.txt +++ /dev/null @@ -1,1809 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - resume_from = os.environ.get("RESUME_FROM", "") - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # Test-time training (LoRA) hyperparameters. - ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) - ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) - ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) - ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x) + (q_delta if q_delta is not None else 0) - k = self.c_k(x) - v = self.c_v(x) + (v_delta if v_delta is not None else 0) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - qd = lora.q_loras[i] if lora else None - vd = lora.v_loras[i] if lora else None - x = self.blocks[i](x, x0, qd, vd) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - qd = lora.q_loras[bi] if lora else None - vd = lora.v_loras[bi] if lora else None - x = self.blocks[bi](x, x0, qd, vd) - x = self.final_norm(x) - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) - else: - logits = self.lm_head(x) - logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - if lora: - bsz, sl, V = logits.shape - return F.cross_entropy( - logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") - - -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. - -BOS_ID = 1 - -class BatchedLinearLoRA(nn.Module): - """LoRA for a linear layer, with independent weights per batch element. - Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" - def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): - super().__init__() - self.in_features = in_features - self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection - self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection - self.reset() - - def forward(self, x: Tensor) -> Tensor: - return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) - - def reset(self) -> None: - bound = 1.0 / math.sqrt(self.in_features) - with torch.no_grad(): - self.A.uniform_(-bound, bound) # kaiming-uniform - self.B.zero_() - -class BatchedTTTLoRA(nn.Module): - """All LoRA adapters for one batch: LM head and Q/V per block.""" - def __init__(self, bsz: int, model: GPT, rank: int): - super().__init__() - dim = model.tok_emb.embedding_dim - vocab = model.tok_emb.num_embeddings - self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) - self.q_loras = nn.ModuleList() - self.v_loras = nn.ModuleList() - for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) - - def reset(self) -> None: - for m in self.modules(): - if isinstance(m, BatchedLinearLoRA): - m.reset() - -def _reset_ttt_optimizer(opt): - for group in opt.param_groups: - for p in group['params']: - s = opt.state.get(p) - if not s: # Fresh state. - continue - s['exp_avg'].zero_() - s['exp_avg_sq'].zero_() - s['step'].fill_(0) - -def _build_ttt_optimizer(lora, args: Hyperparameters): - return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) - -def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: - """Return (start_offset, length) for each document, identified by BOS boundaries. - - If include_next_bos is True, include next document's BOS (to match continuous-stream - eval token count exactly). - """ - bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() - docs = [] - for i in range(len(bos_positions)): - start = int(bos_positions[i]) - end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() - if include_next_bos and i + 1 < len(bos_positions): - end += 1 - assert end - start >= 2 - docs.append((start, end - start)) - return docs - -def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): - """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" - chunk_start = ci * chunk_size - chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size - win_start = max(0, chunk_end - eval_seq_len) - win_len = chunk_end - win_start - chunk_offset = chunk_start - win_start - chunk_len = chunk_end - chunk_start - return win_start, win_len, chunk_offset, chunk_len - -def _accumulate_bpb( - ptl: Tensor, x: Tensor, y: Tensor, - batch_i: int, chunk_offset: int, chunk_len: int, - base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, -): - """Add one doc-chunk's contribution to the running BPB accumulators.""" - lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) - prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] - tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] - tok_bytes = base_bytes_lut[tgt].to(torch.float64) - tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] - loss_sum += lbl.sum() - byte_sum += tok_bytes.sum() - token_count += chunk_len - -def eval_val_ttt_lora( - args: Hyperparameters, - base_model: GPT, - rank: int, - world_size: int, - device: torch.device, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries - files = sorted(glob.glob(args.val_files)) - all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) - docs = _find_docs(all_tokens) - - # Each rank takes a contiguous slice of documents - rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] - chunk_size = args.ttt_chunk_size - eval_seq_len = args.ttt_eval_seq_len - batch_size = args.ttt_batch_size - lora_rank = args.ttt_lora_rank - - rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) - - base_model.eval() - for p in base_model.parameters(): - p.requires_grad_(False) - - lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) - opt = _build_ttt_optimizer(lora, args) - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - byte_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - - for bi in range(0, len(rank_docs), batch_size): - batch = rank_docs[bi:bi + batch_size] - bsz = len(batch) - - if bsz == batch_size: - cur_lora, cur_opt = lora, opt - cur_lora.reset() - _reset_ttt_optimizer(cur_opt) - else: - cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) - cur_opt = _build_ttt_optimizer(cur_lora, args) - - pred_lens = [doc_len - 1 for _, doc_len in batch] - num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] - max_nc = max(num_chunks) - - for ci in range(max_nc): - chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) - context_size, chunk_offset = chunk_stats[1], chunk_stats[2] - - active = [ci < nc for nc in num_chunks] - needs_train = any(ci < nc - 1 for nc in num_chunks) - - x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - doc_info = [] # (chunk_offset, chunk_len) per doc - for b in range(bsz): - if not active[b]: - doc_info.append((0, 0)) - continue - ds, dl = batch[b] - ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) - chunk = all_tokens[ds + ws: ds + ws + wl + 1] - toks = chunk.to(dtype=torch.int64, device=device) - x[b, :wl] = toks[:-1] - y[b, :wl] = toks[1:] - doc_info.append((co, cl)) - - # Forward pass (keep grad graph alive only when we need to train) - if needs_train: - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - else: - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - - # Score: accumulate loss and byte counts for BPB (before training on chunk) - with torch.no_grad(): - for b in range(bsz): - if not active[b]: - continue - co, cl = doc_info[b] - _accumulate_bpb( - ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, loss_sum, byte_sum, token_count) - - # Train: one Adam step on the LoRA params using this chunk's loss - if needs_train: - mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) - per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) - cur_opt.zero_grad() - (per_doc * mask).sum().backward() - cur_opt.step() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - - val_loss = float(loss_sum.item() / token_count.item()) - val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - if isinstance(module, Rotary): - module.inv_freq.data = module.inv_freq.data.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # LoRA test-time training evaluation (the competition score) - torch._dynamo.reset() - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( - args, base_model, rank, world_size, device, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.10.12 (main, Jan 8 2026, 06:52:19) [GCC 11.4.0] -Running PyTorch 2.7.1+cu128 -Thu Mar 19 00:44:40 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | -| N/A 39C P0 109W / 700W | 1449MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | -| N/A 36C P0 112W / 700W | 1449MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 31626 C /usr/bin/python3 1440MiB | -| 1 N/A N/A 31627 C /usr/bin/python3 1440MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:25 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:2 grad_accum_steps:4 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:2400.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9351 train_time:174ms step_avg:173.60ms -step:2/20000 train_loss:16.7655 train_time:342ms step_avg:171.16ms -step:3/20000 train_loss:8.7595 train_time:512ms step_avg:170.53ms -step:4/20000 train_loss:6.6029 train_time:682ms step_avg:170.42ms -step:5/20000 train_loss:6.6597 train_time:851ms step_avg:170.20ms -step:6/20000 train_loss:6.8160 train_time:1020ms step_avg:169.99ms -step:7/20000 train_loss:6.3288 train_time:1189ms step_avg:169.84ms -step:8/20000 train_loss:6.1695 train_time:1358ms step_avg:169.74ms -step:9/20000 train_loss:6.0765 train_time:1527ms step_avg:169.64ms -step:10/20000 train_loss:5.9763 train_time:1696ms step_avg:169.59ms -step:50/20000 train_loss:4.0527 train_time:8482ms step_avg:169.63ms -step:100/20000 train_loss:3.3540 train_time:16978ms step_avg:169.78ms -step:150/20000 train_loss:3.0452 train_time:25470ms step_avg:169.80ms -step:200/20000 train_loss:2.7877 train_time:33973ms step_avg:169.87ms -step:200/20000 val_loss:2.8495 val_bpb:1.6877 train_time:33975ms step_avg:169.87ms -step:250/20000 train_loss:2.7531 train_time:42463ms step_avg:169.85ms -step:300/20000 train_loss:2.5894 train_time:50946ms step_avg:169.82ms -step:350/20000 train_loss:2.6244 train_time:59430ms step_avg:169.80ms -step:400/20000 train_loss:2.3761 train_time:67946ms step_avg:169.87ms -step:400/20000 val_loss:2.5684 val_bpb:1.5211 train_time:67947ms step_avg:169.87ms -step:450/20000 train_loss:2.4938 train_time:76426ms step_avg:169.84ms -step:500/20000 train_loss:2.4739 train_time:84902ms step_avg:169.80ms -step:550/20000 train_loss:2.3493 train_time:93378ms step_avg:169.78ms -step:600/20000 train_loss:2.5328 train_time:101915ms step_avg:169.86ms -step:600/20000 val_loss:2.4503 val_bpb:1.4512 train_time:101916ms step_avg:169.86ms -step:650/20000 train_loss:2.3565 train_time:110390ms step_avg:169.83ms -step:700/20000 train_loss:2.4337 train_time:118865ms step_avg:169.81ms -step:750/20000 train_loss:2.3672 train_time:127340ms step_avg:169.79ms -step:800/20000 train_loss:2.3169 train_time:135831ms step_avg:169.79ms -step:800/20000 val_loss:2.3805 val_bpb:1.4098 train_time:135832ms step_avg:169.79ms -step:850/20000 train_loss:2.4973 train_time:144303ms step_avg:169.77ms -step:900/20000 train_loss:2.3849 train_time:152780ms step_avg:169.76ms -step:950/20000 train_loss:2.4032 train_time:161247ms step_avg:169.73ms -step:1000/20000 train_loss:2.3434 train_time:169728ms step_avg:169.73ms -step:1000/20000 val_loss:2.3368 val_bpb:1.3840 train_time:169729ms step_avg:169.73ms -step:1050/20000 train_loss:2.4253 train_time:178198ms step_avg:169.71ms -step:1100/20000 train_loss:2.3005 train_time:186666ms step_avg:169.70ms -step:1150/20000 train_loss:2.3095 train_time:195168ms step_avg:169.71ms -step:1200/20000 train_loss:2.3013 train_time:203632ms step_avg:169.69ms -step:1200/20000 val_loss:2.3046 val_bpb:1.3649 train_time:203633ms step_avg:169.69ms -step:1250/20000 train_loss:2.2475 train_time:212099ms step_avg:169.68ms -step:1300/20000 train_loss:2.3358 train_time:220560ms step_avg:169.66ms -step:1350/20000 train_loss:2.2053 train_time:229093ms step_avg:169.70ms -step:1400/20000 train_loss:2.3307 train_time:237552ms step_avg:169.68ms -step:1400/20000 val_loss:2.2807 val_bpb:1.3508 train_time:237553ms step_avg:169.68ms -step:1450/20000 train_loss:2.2429 train_time:246012ms step_avg:169.66ms -step:1500/20000 train_loss:2.3022 train_time:254473ms step_avg:169.65ms -step:1550/20000 train_loss:2.4251 train_time:262952ms step_avg:169.65ms -step:1600/20000 train_loss:2.1973 train_time:271422ms step_avg:169.64ms -step:1600/20000 val_loss:2.2679 val_bpb:1.3432 train_time:271423ms step_avg:169.64ms -step:1650/20000 train_loss:2.2489 train_time:279884ms step_avg:169.63ms -step:1700/20000 train_loss:2.2760 train_time:288346ms step_avg:169.62ms -step:1750/20000 train_loss:2.2502 train_time:296823ms step_avg:169.61ms -step:1800/20000 train_loss:2.2752 train_time:305279ms step_avg:169.60ms -step:1800/20000 val_loss:2.2507 val_bpb:1.3330 train_time:305280ms step_avg:169.60ms -step:1850/20000 train_loss:2.2722 train_time:313737ms step_avg:169.59ms -step:1900/20000 train_loss:2.2218 train_time:322193ms step_avg:169.58ms -step:1950/20000 train_loss:2.2899 train_time:330692ms step_avg:169.59ms -step:2000/20000 train_loss:2.2005 train_time:339154ms step_avg:169.58ms -step:2000/20000 val_loss:2.2358 val_bpb:1.3241 train_time:339155ms step_avg:169.58ms -step:2050/20000 train_loss:2.2737 train_time:347614ms step_avg:169.57ms -step:2100/20000 train_loss:2.3770 train_time:356141ms step_avg:169.59ms -step:2150/20000 train_loss:2.1612 train_time:364596ms step_avg:169.58ms -step:2200/20000 train_loss:2.1348 train_time:373058ms step_avg:169.57ms -step:2200/20000 val_loss:2.2278 val_bpb:1.3194 train_time:373059ms step_avg:169.57ms -step:2250/20000 train_loss:2.1900 train_time:381521ms step_avg:169.56ms -step:2300/20000 train_loss:2.2810 train_time:389996ms step_avg:169.56ms -step:2350/20000 train_loss:2.1192 train_time:398457ms step_avg:169.56ms -step:2400/20000 train_loss:2.2346 train_time:406915ms step_avg:169.55ms -step:2400/20000 val_loss:2.2170 val_bpb:1.3130 train_time:406916ms step_avg:169.55ms -step:2450/20000 train_loss:2.2517 train_time:415371ms step_avg:169.54ms -step:2500/20000 train_loss:2.1884 train_time:423843ms step_avg:169.54ms -step:2550/20000 train_loss:2.1759 train_time:432305ms step_avg:169.53ms -step:2600/20000 train_loss:2.2675 train_time:440770ms step_avg:169.53ms -step:2600/20000 val_loss:2.2190 val_bpb:1.3142 train_time:440772ms step_avg:169.53ms -step:2650/20000 train_loss:2.2572 train_time:449226ms step_avg:169.52ms -step:2700/20000 train_loss:2.1910 train_time:457747ms step_avg:169.54ms -step:2750/20000 train_loss:2.2221 train_time:466205ms step_avg:169.53ms -step:2800/20000 train_loss:2.2557 train_time:474665ms step_avg:169.52ms -step:2800/20000 val_loss:2.2019 val_bpb:1.3041 train_time:474666ms step_avg:169.52ms -step:2850/20000 train_loss:2.1905 train_time:483125ms step_avg:169.52ms -step:2900/20000 train_loss:2.1932 train_time:491602ms step_avg:169.52ms -step:2950/20000 train_loss:2.2255 train_time:500060ms step_avg:169.51ms -step:3000/20000 train_loss:2.1296 train_time:508513ms step_avg:169.50ms -step:3000/20000 val_loss:2.1951 val_bpb:1.3000 train_time:508514ms step_avg:169.50ms -step:3050/20000 train_loss:2.1954 train_time:516970ms step_avg:169.50ms -step:3100/20000 train_loss:2.2106 train_time:525447ms step_avg:169.50ms -step:3150/20000 train_loss:2.0825 train_time:533905ms step_avg:169.49ms -step:3200/20000 train_loss:2.2203 train_time:542364ms step_avg:169.49ms -step:3200/20000 val_loss:2.1896 val_bpb:1.2968 train_time:542364ms step_avg:169.49ms -step:3250/20000 train_loss:2.1742 train_time:550863ms step_avg:169.50ms -step:3300/20000 train_loss:2.2282 train_time:559322ms step_avg:169.49ms -step:3350/20000 train_loss:2.1648 train_time:567780ms step_avg:169.49ms -step:3400/20000 train_loss:2.2120 train_time:576238ms step_avg:169.48ms -step:3400/20000 val_loss:2.1866 val_bpb:1.2950 train_time:576239ms step_avg:169.48ms -step:3450/20000 train_loss:2.0896 train_time:584768ms step_avg:169.50ms -step:3500/20000 train_loss:2.2498 train_time:593228ms step_avg:169.49ms -step:3550/20000 train_loss:2.2177 train_time:601685ms step_avg:169.49ms -step:3600/20000 train_loss:2.1529 train_time:610142ms step_avg:169.48ms -step:3600/20000 val_loss:2.1792 val_bpb:1.2907 train_time:610143ms step_avg:169.48ms -step:3650/20000 train_loss:2.1788 train_time:618624ms step_avg:169.49ms -step:3700/20000 train_loss:2.1938 train_time:627082ms step_avg:169.48ms -step:3750/20000 train_loss:2.1617 train_time:635540ms step_avg:169.48ms -step:3800/20000 train_loss:2.2400 train_time:644003ms step_avg:169.47ms -step:3800/20000 val_loss:2.1749 val_bpb:1.2881 train_time:644004ms step_avg:169.47ms -step:3850/20000 train_loss:2.1662 train_time:652483ms step_avg:169.48ms -step:3900/20000 train_loss:2.1166 train_time:660938ms step_avg:169.47ms -step:3950/20000 train_loss:2.2001 train_time:669395ms step_avg:169.47ms -step:4000/20000 train_loss:2.1281 train_time:677853ms step_avg:169.46ms -step:4000/20000 val_loss:2.1701 val_bpb:1.2853 train_time:677854ms step_avg:169.46ms -step:4050/20000 train_loss:2.1296 train_time:686353ms step_avg:169.47ms -step:4100/20000 train_loss:2.1709 train_time:694811ms step_avg:169.47ms -step:4150/20000 train_loss:2.2218 train_time:703269ms step_avg:169.46ms -step:4200/20000 train_loss:2.2025 train_time:711795ms step_avg:169.48ms -step:4200/20000 val_loss:2.1667 val_bpb:1.2832 train_time:711796ms step_avg:169.48ms -step:4250/20000 train_loss:2.2002 train_time:720251ms step_avg:169.47ms -step:4300/20000 train_loss:2.0832 train_time:728706ms step_avg:169.47ms -step:4350/20000 train_loss:2.2102 train_time:737162ms step_avg:169.46ms -step:4400/20000 train_loss:2.1602 train_time:745650ms step_avg:169.47ms -step:4400/20000 val_loss:2.1669 val_bpb:1.2834 train_time:745651ms step_avg:169.47ms -step:4450/20000 train_loss:2.0757 train_time:754108ms step_avg:169.46ms -step:4500/20000 train_loss:2.1909 train_time:762569ms step_avg:169.46ms -step:4550/20000 train_loss:2.1664 train_time:771028ms step_avg:169.46ms -step:4600/20000 train_loss:2.0006 train_time:779507ms step_avg:169.46ms -step:4600/20000 val_loss:2.1626 val_bpb:1.2808 train_time:779507ms step_avg:169.46ms -step:4650/20000 train_loss:2.1337 train_time:787966ms step_avg:169.46ms -step:4700/20000 train_loss:2.2058 train_time:796426ms step_avg:169.45ms -step:4750/20000 train_loss:2.0283 train_time:804882ms step_avg:169.45ms -step:4800/20000 train_loss:2.1505 train_time:813374ms step_avg:169.45ms -step:4800/20000 val_loss:2.1570 val_bpb:1.2775 train_time:813375ms step_avg:169.45ms -step:4850/20000 train_loss:2.1672 train_time:821834ms step_avg:169.45ms -step:4900/20000 train_loss:2.2597 train_time:830288ms step_avg:169.45ms -step:4950/20000 train_loss:2.1454 train_time:838740ms step_avg:169.44ms -step:5000/20000 train_loss:2.1790 train_time:847257ms step_avg:169.45ms -step:5000/20000 val_loss:2.1546 val_bpb:1.2761 train_time:847258ms step_avg:169.45ms -step:5050/20000 train_loss:2.0755 train_time:855713ms step_avg:169.45ms -step:5100/20000 train_loss:2.1092 train_time:864167ms step_avg:169.44ms -step:5150/20000 train_loss:2.1724 train_time:872650ms step_avg:169.45ms -step:5200/20000 train_loss:2.2597 train_time:881106ms step_avg:169.44ms -step:5200/20000 val_loss:2.1513 val_bpb:1.2741 train_time:881107ms step_avg:169.44ms -step:5250/20000 train_loss:2.1363 train_time:889565ms step_avg:169.44ms -step:5300/20000 train_loss:2.1027 train_time:898022ms step_avg:169.44ms -step:5350/20000 train_loss:2.2834 train_time:906518ms step_avg:169.44ms -step:5400/20000 train_loss:2.2248 train_time:914977ms step_avg:169.44ms -step:5400/20000 val_loss:2.1494 val_bpb:1.2730 train_time:914978ms step_avg:169.44ms -step:5450/20000 train_loss:2.1689 train_time:923434ms step_avg:169.44ms -step:5500/20000 train_loss:2.2277 train_time:931892ms step_avg:169.43ms -step:5550/20000 train_loss:2.1456 train_time:940413ms step_avg:169.44ms -step:5600/20000 train_loss:2.2142 train_time:948874ms step_avg:169.44ms -step:5600/20000 val_loss:2.1469 val_bpb:1.2715 train_time:948874ms step_avg:169.44ms -step:5650/20000 train_loss:2.0974 train_time:957327ms step_avg:169.44ms -step:5700/20000 train_loss:2.1700 train_time:965778ms step_avg:169.43ms -step:5750/20000 train_loss:2.1025 train_time:974252ms step_avg:169.44ms -step:5800/20000 train_loss:2.1325 train_time:982704ms step_avg:169.43ms -step:5800/20000 val_loss:2.1447 val_bpb:1.2702 train_time:982706ms step_avg:169.43ms -step:5850/20000 train_loss:2.0898 train_time:991160ms step_avg:169.43ms -step:5900/20000 train_loss:2.1624 train_time:999615ms step_avg:169.43ms -step:5950/20000 train_loss:2.1254 train_time:1008086ms step_avg:169.43ms -step:6000/20000 train_loss:2.1318 train_time:1016539ms step_avg:169.42ms -step:6000/20000 val_loss:2.1411 val_bpb:1.2681 train_time:1016539ms step_avg:169.42ms -step:6050/20000 train_loss:2.0401 train_time:1024998ms step_avg:169.42ms -step:6100/20000 train_loss:2.2274 train_time:1033454ms step_avg:169.42ms -step:6150/20000 train_loss:2.0974 train_time:1041946ms step_avg:169.42ms -step:6200/20000 train_loss:2.0789 train_time:1050405ms step_avg:169.42ms -step:6200/20000 val_loss:2.1392 val_bpb:1.2669 train_time:1050407ms step_avg:169.42ms -step:6250/20000 train_loss:2.1952 train_time:1058860ms step_avg:169.42ms -step:6300/20000 train_loss:2.0949 train_time:1067377ms step_avg:169.42ms -step:6350/20000 train_loss:2.1056 train_time:1075831ms step_avg:169.42ms -step:6400/20000 train_loss:2.1028 train_time:1084288ms step_avg:169.42ms -step:6400/20000 val_loss:2.1391 val_bpb:1.2669 train_time:1084289ms step_avg:169.42ms -step:6450/20000 train_loss:2.2170 train_time:1092746ms step_avg:169.42ms -step:6500/20000 train_loss:2.0875 train_time:1101225ms step_avg:169.42ms -step:6550/20000 train_loss:2.2588 train_time:1109684ms step_avg:169.42ms -step:6600/20000 train_loss:2.0452 train_time:1118136ms step_avg:169.41ms -step:6600/20000 val_loss:2.1361 val_bpb:1.2651 train_time:1118137ms step_avg:169.41ms -step:6650/20000 train_loss:2.1288 train_time:1126589ms step_avg:169.41ms -step:6700/20000 train_loss:2.1602 train_time:1135071ms step_avg:169.41ms -step:6750/20000 train_loss:2.1690 train_time:1143528ms step_avg:169.41ms -step:6800/20000 train_loss:2.0779 train_time:1151984ms step_avg:169.41ms -step:6800/20000 val_loss:2.1372 val_bpb:1.2657 train_time:1151986ms step_avg:169.41ms -step:6850/20000 train_loss:2.1198 train_time:1160444ms step_avg:169.41ms -step:6900/20000 train_loss:2.1455 train_time:1168938ms step_avg:169.41ms -step:6950/20000 train_loss:2.1165 train_time:1177392ms step_avg:169.41ms -step:7000/20000 train_loss:2.1155 train_time:1185846ms step_avg:169.41ms -step:7000/20000 val_loss:2.1331 val_bpb:1.2633 train_time:1185847ms step_avg:169.41ms -step:7050/20000 train_loss:2.0908 train_time:1194300ms step_avg:169.40ms -step:7100/20000 train_loss:2.2298 train_time:1202820ms step_avg:169.41ms -step:7150/20000 train_loss:2.1186 train_time:1211276ms step_avg:169.41ms -step:7200/20000 train_loss:2.0549 train_time:1219731ms step_avg:169.41ms -step:7200/20000 val_loss:2.1317 val_bpb:1.2625 train_time:1219732ms step_avg:169.41ms -step:7250/20000 train_loss:2.0966 train_time:1228219ms step_avg:169.41ms -step:7300/20000 train_loss:2.1691 train_time:1236675ms step_avg:169.41ms -step:7350/20000 train_loss:2.2159 train_time:1245134ms step_avg:169.41ms -step:7400/20000 train_loss:2.1267 train_time:1253592ms step_avg:169.40ms -step:7400/20000 val_loss:2.1293 val_bpb:1.2611 train_time:1253593ms step_avg:169.40ms -step:7450/20000 train_loss:2.1447 train_time:1262074ms step_avg:169.41ms -step:7500/20000 train_loss:2.1555 train_time:1270528ms step_avg:169.40ms -step:7550/20000 train_loss:2.1362 train_time:1278984ms step_avg:169.40ms -step:7600/20000 train_loss:2.1267 train_time:1287440ms step_avg:169.40ms -step:7600/20000 val_loss:2.1281 val_bpb:1.2604 train_time:1287441ms step_avg:169.40ms -step:7650/20000 train_loss:2.0994 train_time:1295939ms step_avg:169.40ms -step:7700/20000 train_loss:2.1604 train_time:1304395ms step_avg:169.40ms -step:7750/20000 train_loss:2.1470 train_time:1312849ms step_avg:169.40ms -step:7800/20000 train_loss:2.1313 train_time:1321307ms step_avg:169.40ms -step:7800/20000 val_loss:2.1280 val_bpb:1.2603 train_time:1321308ms step_avg:169.40ms -step:7850/20000 train_loss:2.1290 train_time:1329830ms step_avg:169.41ms -step:7900/20000 train_loss:2.2608 train_time:1338285ms step_avg:169.40ms -step:7950/20000 train_loss:2.0942 train_time:1346741ms step_avg:169.40ms -step:8000/20000 train_loss:2.0943 train_time:1355193ms step_avg:169.40ms -step:8000/20000 val_loss:2.1257 val_bpb:1.2589 train_time:1355194ms step_avg:169.40ms -step:8050/20000 train_loss:2.1504 train_time:1363675ms step_avg:169.40ms -step:8100/20000 train_loss:2.1743 train_time:1372130ms step_avg:169.40ms -step:8150/20000 train_loss:2.0741 train_time:1380584ms step_avg:169.40ms -step:8200/20000 train_loss:2.1584 train_time:1389042ms step_avg:169.40ms -step:8200/20000 val_loss:2.1262 val_bpb:1.2593 train_time:1389044ms step_avg:169.40ms -step:8250/20000 train_loss:2.0769 train_time:1397544ms step_avg:169.40ms -step:8300/20000 train_loss:2.1049 train_time:1405999ms step_avg:169.40ms -step:8350/20000 train_loss:2.1274 train_time:1414451ms step_avg:169.40ms -step:8400/20000 train_loss:2.3278 train_time:1422982ms step_avg:169.40ms -step:8400/20000 val_loss:2.1218 val_bpb:1.2567 train_time:1422983ms step_avg:169.40ms -step:8450/20000 train_loss:2.2107 train_time:1431439ms step_avg:169.40ms -step:8500/20000 train_loss:2.1143 train_time:1439894ms step_avg:169.40ms -step:8550/20000 train_loss:2.1998 train_time:1448350ms step_avg:169.40ms -step:8600/20000 train_loss:2.1278 train_time:1456834ms step_avg:169.40ms -step:8600/20000 val_loss:2.1196 val_bpb:1.2553 train_time:1456835ms step_avg:169.40ms -step:8650/20000 train_loss:2.0951 train_time:1465298ms step_avg:169.40ms -step:8700/20000 train_loss:2.0944 train_time:1473751ms step_avg:169.40ms -step:8750/20000 train_loss:2.1285 train_time:1482206ms step_avg:169.39ms -step:8800/20000 train_loss:2.0919 train_time:1490680ms step_avg:169.40ms -step:8800/20000 val_loss:2.1174 val_bpb:1.2540 train_time:1490681ms step_avg:169.40ms -step:8850/20000 train_loss:2.1088 train_time:1499139ms step_avg:169.39ms -step:8900/20000 train_loss:2.1639 train_time:1507596ms step_avg:169.39ms -step:8950/20000 train_loss:2.1578 train_time:1516053ms step_avg:169.39ms -step:9000/20000 train_loss:2.1760 train_time:1524549ms step_avg:169.39ms -step:9000/20000 val_loss:2.1168 val_bpb:1.2537 train_time:1524550ms step_avg:169.39ms -step:9050/20000 train_loss:2.3030 train_time:1533016ms step_avg:169.39ms -step:9100/20000 train_loss:2.1155 train_time:1541470ms step_avg:169.39ms -step:9150/20000 train_loss:2.1035 train_time:1549926ms step_avg:169.39ms -step:9200/20000 train_loss:2.1729 train_time:1558444ms step_avg:169.40ms -step:9200/20000 val_loss:2.1156 val_bpb:1.2530 train_time:1558445ms step_avg:169.40ms -step:9250/20000 train_loss:2.2299 train_time:1566902ms step_avg:169.39ms -step:9300/20000 train_loss:2.1980 train_time:1575358ms step_avg:169.39ms -step:9350/20000 train_loss:2.1314 train_time:1583840ms step_avg:169.39ms -step:9400/20000 train_loss:1.9838 train_time:1592299ms step_avg:169.39ms -step:9400/20000 val_loss:2.1162 val_bpb:1.2533 train_time:1592300ms step_avg:169.39ms -step:9450/20000 train_loss:2.0636 train_time:1600750ms step_avg:169.39ms -step:9500/20000 train_loss:2.1437 train_time:1609210ms step_avg:169.39ms -step:9550/20000 train_loss:2.1144 train_time:1617688ms step_avg:169.39ms -step:9600/20000 train_loss:2.1365 train_time:1626144ms step_avg:169.39ms -step:9600/20000 val_loss:2.1153 val_bpb:1.2528 train_time:1626145ms step_avg:169.39ms -step:9650/20000 train_loss:2.0492 train_time:1634603ms step_avg:169.39ms -step:9700/20000 train_loss:2.0978 train_time:1643056ms step_avg:169.39ms -step:9750/20000 train_loss:2.1239 train_time:1651550ms step_avg:169.39ms -step:9800/20000 train_loss:2.1238 train_time:1660003ms step_avg:169.39ms -step:9800/20000 val_loss:2.1118 val_bpb:1.2507 train_time:1660004ms step_avg:169.39ms -step:9850/20000 train_loss:2.1384 train_time:1668461ms step_avg:169.39ms -step:9900/20000 train_loss:2.0547 train_time:1676913ms step_avg:169.39ms -step:9950/20000 train_loss:2.1484 train_time:1685437ms step_avg:169.39ms -step:10000/20000 train_loss:2.1270 train_time:1693893ms step_avg:169.39ms -step:10000/20000 val_loss:2.1137 val_bpb:1.2518 train_time:1693895ms step_avg:169.39ms -step:10050/20000 train_loss:2.1228 train_time:1702350ms step_avg:169.39ms -step:10100/20000 train_loss:2.1150 train_time:1710804ms step_avg:169.39ms -step:10150/20000 train_loss:2.1118 train_time:1719285ms step_avg:169.39ms -step:10200/20000 train_loss:2.1696 train_time:1727746ms step_avg:169.39ms -step:10200/20000 val_loss:2.1115 val_bpb:1.2506 train_time:1727747ms step_avg:169.39ms -step:10250/20000 train_loss:2.1311 train_time:1736204ms step_avg:169.39ms -step:10300/20000 train_loss:2.0295 train_time:1744681ms step_avg:169.39ms -step:10350/20000 train_loss:2.0417 train_time:1753133ms step_avg:169.38ms -step:10400/20000 train_loss:2.1018 train_time:1761588ms step_avg:169.38ms -step:10400/20000 val_loss:2.1105 val_bpb:1.2500 train_time:1761590ms step_avg:169.38ms -step:10450/20000 train_loss:2.0881 train_time:1770042ms step_avg:169.38ms -step:10500/20000 train_loss:2.0357 train_time:1778533ms step_avg:169.38ms -step:10550/20000 train_loss:2.1089 train_time:1786991ms step_avg:169.38ms -step:10600/20000 train_loss:2.0764 train_time:1795446ms step_avg:169.38ms -step:10600/20000 val_loss:2.1091 val_bpb:1.2492 train_time:1795447ms step_avg:169.38ms -step:10650/20000 train_loss:2.0870 train_time:1803907ms step_avg:169.38ms -step:10700/20000 train_loss:2.1576 train_time:1812386ms step_avg:169.38ms -step:10750/20000 train_loss:2.0815 train_time:1820845ms step_avg:169.38ms -step:10800/20000 train_loss:2.1847 train_time:1829301ms step_avg:169.38ms -step:10800/20000 val_loss:2.1083 val_bpb:1.2487 train_time:1829302ms step_avg:169.38ms -step:10850/20000 train_loss:2.0934 train_time:1837757ms step_avg:169.38ms -step:10900/20000 train_loss:2.1508 train_time:1846229ms step_avg:169.38ms -step:10950/20000 train_loss:2.0658 train_time:1854686ms step_avg:169.38ms -step:11000/20000 train_loss:2.0707 train_time:1863145ms step_avg:169.38ms -step:11000/20000 val_loss:2.1088 val_bpb:1.2489 train_time:1863146ms step_avg:169.38ms -step:11050/20000 train_loss:2.0967 train_time:1871601ms step_avg:169.38ms -step:11100/20000 train_loss:2.1133 train_time:1880099ms step_avg:169.38ms -step:11150/20000 train_loss:2.1477 train_time:1888555ms step_avg:169.38ms -step:11200/20000 train_loss:2.0874 train_time:1897016ms step_avg:169.38ms -step:11200/20000 val_loss:2.1066 val_bpb:1.2476 train_time:1897018ms step_avg:169.38ms -step:11250/20000 train_loss:2.1173 train_time:1905471ms step_avg:169.38ms -step:11300/20000 train_loss:2.1713 train_time:1913994ms step_avg:169.38ms -step:11350/20000 train_loss:2.0768 train_time:1922449ms step_avg:169.38ms -step:11400/20000 train_loss:2.1331 train_time:1930902ms step_avg:169.38ms -step:11400/20000 val_loss:2.1046 val_bpb:1.2465 train_time:1930903ms step_avg:169.38ms -step:11450/20000 train_loss:2.1208 train_time:1939381ms step_avg:169.38ms -step:11500/20000 train_loss:2.0917 train_time:1947838ms step_avg:169.38ms -step:11550/20000 train_loss:2.1384 train_time:1956292ms step_avg:169.38ms -step:11600/20000 train_loss:2.0931 train_time:1964750ms step_avg:169.37ms -step:11600/20000 val_loss:2.1046 val_bpb:1.2465 train_time:1964751ms step_avg:169.38ms -step:11650/20000 train_loss:2.1396 train_time:1973232ms step_avg:169.38ms -step:11700/20000 train_loss:2.1252 train_time:1981690ms step_avg:169.38ms -step:11750/20000 train_loss:2.1119 train_time:1990145ms step_avg:169.37ms -step:11800/20000 train_loss:1.9639 train_time:1998603ms step_avg:169.37ms -step:11800/20000 val_loss:2.1059 val_bpb:1.2472 train_time:1998604ms step_avg:169.37ms -step:11850/20000 train_loss:2.0749 train_time:2007105ms step_avg:169.38ms -step:11900/20000 train_loss:2.1753 train_time:2015561ms step_avg:169.37ms -step:11950/20000 train_loss:2.0763 train_time:2024018ms step_avg:169.37ms -step:12000/20000 train_loss:2.1209 train_time:2032474ms step_avg:169.37ms -step:12000/20000 val_loss:2.1037 val_bpb:1.2460 train_time:2032476ms step_avg:169.37ms -step:12050/20000 train_loss:2.1288 train_time:2040997ms step_avg:169.38ms -step:12100/20000 train_loss:2.0873 train_time:2049454ms step_avg:169.38ms -step:12150/20000 train_loss:2.0710 train_time:2057909ms step_avg:169.38ms -step:12200/20000 train_loss:2.1239 train_time:2066368ms step_avg:169.37ms -step:12200/20000 val_loss:2.1023 val_bpb:1.2451 train_time:2066369ms step_avg:169.37ms -step:12250/20000 train_loss:2.1199 train_time:2074844ms step_avg:169.38ms -step:12300/20000 train_loss:2.2217 train_time:2083300ms step_avg:169.37ms -step:12350/20000 train_loss:2.1110 train_time:2091752ms step_avg:169.37ms -step:12400/20000 train_loss:2.0584 train_time:2100227ms step_avg:169.37ms -step:12400/20000 val_loss:2.1013 val_bpb:1.2445 train_time:2100228ms step_avg:169.37ms -step:12450/20000 train_loss:2.1339 train_time:2108687ms step_avg:169.37ms -step:12500/20000 train_loss:2.1787 train_time:2117144ms step_avg:169.37ms -step:12550/20000 train_loss:2.1007 train_time:2125599ms step_avg:169.37ms -step:12600/20000 train_loss:2.1660 train_time:2134096ms step_avg:169.37ms -step:12600/20000 val_loss:2.1004 val_bpb:1.2440 train_time:2134097ms step_avg:169.37ms -step:12650/20000 train_loss:2.1783 train_time:2142555ms step_avg:169.37ms -step:12700/20000 train_loss:2.1038 train_time:2151013ms step_avg:169.37ms -step:12750/20000 train_loss:2.0914 train_time:2159472ms step_avg:169.37ms -step:12800/20000 train_loss:2.0869 train_time:2167995ms step_avg:169.37ms -step:12800/20000 val_loss:2.1015 val_bpb:1.2446 train_time:2167996ms step_avg:169.37ms -step:12850/20000 train_loss:2.0884 train_time:2176451ms step_avg:169.37ms -step:12900/20000 train_loss:2.2230 train_time:2184910ms step_avg:169.37ms -step:12950/20000 train_loss:2.0790 train_time:2193362ms step_avg:169.37ms -step:13000/20000 train_loss:2.1156 train_time:2201839ms step_avg:169.37ms -step:13000/20000 val_loss:2.0991 val_bpb:1.2432 train_time:2201840ms step_avg:169.37ms -step:13050/20000 train_loss:2.0904 train_time:2210302ms step_avg:169.37ms -step:13100/20000 train_loss:2.0774 train_time:2218756ms step_avg:169.37ms -step:13150/20000 train_loss:2.1084 train_time:2227216ms step_avg:169.37ms -step:13200/20000 train_loss:2.1958 train_time:2235693ms step_avg:169.37ms -step:13200/20000 val_loss:2.0925 val_bpb:1.2393 train_time:2235694ms step_avg:169.37ms -step:13250/20000 train_loss:2.0877 train_time:2244150ms step_avg:169.37ms -step:13300/20000 train_loss:2.1376 train_time:2252603ms step_avg:169.37ms -step:13350/20000 train_loss:2.0734 train_time:2261062ms step_avg:169.37ms -step:13400/20000 train_loss:2.1114 train_time:2269588ms step_avg:169.37ms -step:13400/20000 val_loss:2.0854 val_bpb:1.2351 train_time:2269590ms step_avg:169.37ms -step:13450/20000 train_loss:2.1410 train_time:2278042ms step_avg:169.37ms -step:13500/20000 train_loss:2.0551 train_time:2286500ms step_avg:169.37ms -step:13550/20000 train_loss:2.0534 train_time:2294976ms step_avg:169.37ms -step:13600/20000 train_loss:2.0784 train_time:2303429ms step_avg:169.37ms -step:13600/20000 val_loss:2.0778 val_bpb:1.2306 train_time:2303431ms step_avg:169.37ms -step:13650/20000 train_loss:2.0752 train_time:2311884ms step_avg:169.37ms -step:13700/20000 train_loss:2.1112 train_time:2320335ms step_avg:169.37ms -step:13750/20000 train_loss:2.0833 train_time:2328810ms step_avg:169.37ms -step:13800/20000 train_loss:2.0605 train_time:2337262ms step_avg:169.37ms -step:13800/20000 val_loss:2.0705 val_bpb:1.2263 train_time:2337263ms step_avg:169.37ms -step:13850/20000 train_loss:2.0858 train_time:2345717ms step_avg:169.37ms -step:13900/20000 train_loss:2.0457 train_time:2354171ms step_avg:169.36ms -step:13950/20000 train_loss:2.1323 train_time:2362671ms step_avg:169.37ms -step:14000/20000 train_loss:2.0190 train_time:2371126ms step_avg:169.37ms -step:14000/20000 val_loss:2.0635 val_bpb:1.2221 train_time:2371127ms step_avg:169.37ms -step:14050/20000 train_loss:2.0494 train_time:2379579ms step_avg:169.37ms -step:14100/20000 train_loss:2.0507 train_time:2388034ms step_avg:169.36ms -step:14150/20000 train_loss:2.0924 train_time:2396555ms step_avg:169.37ms -step:14171/20000 val_loss:2.0593 val_bpb:1.2197 train_time:2400107ms step_avg:169.37ms -stopping_early: wallclock_cap train_time:2400107ms step:14171/20000 -peak memory allocated: 10334 MiB reserved: 10348 MiB -Serialized model: 67224983 bytes -Code size: 47686 bytes -Total submission size: 67272669 bytes -Serialized model int8+zlib: 15815345 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15863031 bytes -final_int8_zlib_roundtrip val_loss:2.0723 val_bpb:1.2274 eval_time:1375ms -final_int8_zlib_roundtrip_exact val_loss:2.07234577 val_bpb:1.22735990 -final_ttt_lora val_loss:2.0139 val_bpb:1.1927 eval_time:60031ms diff --git a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v2.txt b/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v2.txt deleted file mode 100644 index 56ede1655e..0000000000 --- a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v2.txt +++ /dev/null @@ -1,1829 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - resume_from = os.environ.get("RESUME_FROM", "") - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # Test-time training (LoRA) hyperparameters. - ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) - ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) - ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) - ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x) + (q_delta if q_delta is not None else 0) - k = self.c_k(x) - v = self.c_v(x) + (v_delta if v_delta is not None else 0) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - qd = lora.q_loras[i] if lora else None - vd = lora.v_loras[i] if lora else None - x = self.blocks[i](x, x0, qd, vd) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - qd = lora.q_loras[bi] if lora else None - vd = lora.v_loras[bi] if lora else None - x = self.blocks[bi](x, x0, qd, vd) - x = self.final_norm(x) - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) - else: - logits = self.lm_head(x) - logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - if lora: - bsz, sl, V = logits.shape - return F.cross_entropy( - logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") - - -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. - -BOS_ID = 1 - -class BatchedLinearLoRA(nn.Module): - """LoRA for a linear layer, with independent weights per batch element. - Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" - def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): - super().__init__() - self.in_features = in_features - self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection - self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection - self.reset() - - def forward(self, x: Tensor) -> Tensor: - return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) - - def reset(self) -> None: - bound = 1.0 / math.sqrt(self.in_features) - with torch.no_grad(): - self.A.uniform_(-bound, bound) # kaiming-uniform - self.B.zero_() - -class BatchedTTTLoRA(nn.Module): - """All LoRA adapters for one batch: LM head and Q/V per block.""" - def __init__(self, bsz: int, model: GPT, rank: int): - super().__init__() - dim = model.tok_emb.embedding_dim - vocab = model.tok_emb.num_embeddings - self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) - self.q_loras = nn.ModuleList() - self.v_loras = nn.ModuleList() - for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) - - def reset(self) -> None: - for m in self.modules(): - if isinstance(m, BatchedLinearLoRA): - m.reset() - -def _reset_ttt_optimizer(opt): - for group in opt.param_groups: - for p in group['params']: - s = opt.state.get(p) - if not s: # Fresh state. - continue - s['exp_avg'].zero_() - s['exp_avg_sq'].zero_() - s['step'].fill_(0) - -def _build_ttt_optimizer(lora, args: Hyperparameters): - return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) - -def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: - """Return (start_offset, length) for each document, identified by BOS boundaries. - - If include_next_bos is True, include next document's BOS (to match continuous-stream - eval token count exactly). - """ - bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() - docs = [] - for i in range(len(bos_positions)): - start = int(bos_positions[i]) - end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() - if include_next_bos and i + 1 < len(bos_positions): - end += 1 - assert end - start >= 2 - docs.append((start, end - start)) - return docs - -def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): - """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" - chunk_start = ci * chunk_size - chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size - win_start = max(0, chunk_end - eval_seq_len) - win_len = chunk_end - win_start - chunk_offset = chunk_start - win_start - chunk_len = chunk_end - chunk_start - return win_start, win_len, chunk_offset, chunk_len - -def _accumulate_bpb( - ptl: Tensor, x: Tensor, y: Tensor, - batch_i: int, chunk_offset: int, chunk_len: int, - base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, -): - """Add one doc-chunk's contribution to the running BPB accumulators.""" - lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) - prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] - tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] - tok_bytes = base_bytes_lut[tgt].to(torch.float64) - tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] - loss_sum += lbl.sum() - byte_sum += tok_bytes.sum() - token_count += chunk_len - -def eval_val_ttt_lora( - args: Hyperparameters, - base_model: GPT, - rank: int, - world_size: int, - device: torch.device, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries - files = sorted(glob.glob(args.val_files)) - all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) - docs = _find_docs(all_tokens) - - # Each rank takes a contiguous slice of documents - rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] - chunk_size = args.ttt_chunk_size - eval_seq_len = args.ttt_eval_seq_len - batch_size = args.ttt_batch_size - lora_rank = args.ttt_lora_rank - - rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) - - base_model.eval() - for p in base_model.parameters(): - p.requires_grad_(False) - - lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) - opt = _build_ttt_optimizer(lora, args) - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - byte_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - - for bi in range(0, len(rank_docs), batch_size): - batch = rank_docs[bi:bi + batch_size] - bsz = len(batch) - - if bsz == batch_size: - cur_lora, cur_opt = lora, opt - cur_lora.reset() - _reset_ttt_optimizer(cur_opt) - else: - cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) - cur_opt = _build_ttt_optimizer(cur_lora, args) - - pred_lens = [doc_len - 1 for _, doc_len in batch] - num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] - max_nc = max(num_chunks) - - for ci in range(max_nc): - chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) - context_size, chunk_offset = chunk_stats[1], chunk_stats[2] - - active = [ci < nc for nc in num_chunks] - needs_train = any(ci < nc - 1 for nc in num_chunks) - - x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - doc_info = [] # (chunk_offset, chunk_len) per doc - for b in range(bsz): - if not active[b]: - doc_info.append((0, 0)) - continue - ds, dl = batch[b] - ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) - chunk = all_tokens[ds + ws: ds + ws + wl + 1] - toks = chunk.to(dtype=torch.int64, device=device) - x[b, :wl] = toks[:-1] - y[b, :wl] = toks[1:] - doc_info.append((co, cl)) - - # Forward pass (keep grad graph alive only when we need to train) - if needs_train: - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - else: - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - - # Score: accumulate loss and byte counts for BPB (before training on chunk) - with torch.no_grad(): - for b in range(bsz): - if not active[b]: - continue - co, cl = doc_info[b] - _accumulate_bpb( - ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, loss_sum, byte_sum, token_count) - - # Train: one Adam step on the LoRA params using this chunk's loss - if needs_train: - mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) - per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) - cur_opt.zero_grad() - (per_doc * mask).sum().backward() - cur_opt.step() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - - val_loss = float(loss_sum.item() / token_count.item()) - val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - if isinstance(module, Rotary): - module.inv_freq.data = module.inv_freq.data.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # LoRA test-time training evaluation (the competition score) - torch._dynamo.reset() - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( - args, base_model, rank, world_size, device, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.13 (main, Mar 10 2026, 18:17:25) [Clang 21.1.4 ] -Running PyTorch 2.10.0+cu128 -Thu Mar 19 10:58:09 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | -| N/A 41C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | -| N/A 35C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 40C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | -| N/A 35C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | -| N/A 42C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | -| N/A 36C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | -| N/A 38C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | -| N/A 34C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 50661 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 1 N/A N/A 50662 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 2 N/A N/A 50663 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 3 N/A N/A 50664 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 4 N/A N/A 50665 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 5 N/A N/A 50666 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 6 N/A N/A 50667 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 7 N/A N/A 50668 C ...ai-codegolf/.venv/bin/python3 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:25 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9370 train_time:24ms step_avg:23.83ms -step:2/20000 train_loss:16.8366 train_time:65ms step_avg:32.57ms -step:3/20000 train_loss:8.7610 train_time:108ms step_avg:36.16ms -step:4/20000 train_loss:6.6384 train_time:152ms step_avg:37.95ms -step:5/20000 train_loss:6.6118 train_time:195ms step_avg:39.03ms -step:6/20000 train_loss:7.4221 train_time:239ms step_avg:39.77ms -step:7/20000 train_loss:6.3501 train_time:282ms step_avg:40.26ms -step:8/20000 train_loss:6.1579 train_time:325ms step_avg:40.64ms -step:9/20000 train_loss:6.0679 train_time:368ms step_avg:40.94ms -step:10/20000 train_loss:5.9746 train_time:412ms step_avg:41.18ms -step:50/20000 train_loss:4.1007 train_time:2143ms step_avg:42.85ms -step:100/20000 train_loss:3.4045 train_time:4307ms step_avg:43.07ms -step:150/20000 train_loss:3.0582 train_time:6470ms step_avg:43.13ms -step:200/20000 train_loss:2.8571 train_time:8707ms step_avg:43.53ms -step:200/20000 val_loss:2.8349 val_bpb:1.6790 train_time:8733ms step_avg:43.66ms -step:250/20000 train_loss:2.7538 train_time:10872ms step_avg:43.49ms -step:300/20000 train_loss:2.4974 train_time:13033ms step_avg:43.44ms -step:350/20000 train_loss:2.6676 train_time:15192ms step_avg:43.41ms -step:400/20000 train_loss:2.3566 train_time:17416ms step_avg:43.54ms -step:400/20000 val_loss:2.5720 val_bpb:1.5233 train_time:17442ms step_avg:43.61ms -step:450/20000 train_loss:2.5119 train_time:19576ms step_avg:43.50ms -step:500/20000 train_loss:2.5032 train_time:21738ms step_avg:43.48ms -step:550/20000 train_loss:2.3963 train_time:23899ms step_avg:43.45ms -step:600/20000 train_loss:2.5467 train_time:26142ms step_avg:43.57ms -step:600/20000 val_loss:2.4485 val_bpb:1.4501 train_time:26167ms step_avg:43.61ms -step:650/20000 train_loss:2.3826 train_time:28301ms step_avg:43.54ms -step:700/20000 train_loss:2.4381 train_time:30460ms step_avg:43.51ms -step:750/20000 train_loss:2.2753 train_time:32621ms step_avg:43.49ms -step:800/20000 train_loss:2.2975 train_time:34849ms step_avg:43.56ms -step:800/20000 val_loss:2.3804 val_bpb:1.4098 train_time:34875ms step_avg:43.59ms -step:850/20000 train_loss:2.7153 train_time:37009ms step_avg:43.54ms -step:900/20000 train_loss:2.3410 train_time:39167ms step_avg:43.52ms -step:950/20000 train_loss:2.4045 train_time:41325ms step_avg:43.50ms -step:1000/20000 train_loss:2.3758 train_time:43560ms step_avg:43.56ms -step:1000/20000 val_loss:2.3351 val_bpb:1.3830 train_time:43587ms step_avg:43.59ms -step:1050/20000 train_loss:2.4854 train_time:45720ms step_avg:43.54ms -step:1100/20000 train_loss:2.2606 train_time:47875ms step_avg:43.52ms -step:1150/20000 train_loss:2.2543 train_time:50101ms step_avg:43.57ms -step:1200/20000 train_loss:2.3901 train_time:52259ms step_avg:43.55ms -step:1200/20000 val_loss:2.3027 val_bpb:1.3638 train_time:52285ms step_avg:43.57ms -step:1250/20000 train_loss:2.2083 train_time:54418ms step_avg:43.53ms -step:1300/20000 train_loss:2.3630 train_time:56577ms step_avg:43.52ms -step:1350/20000 train_loss:2.2739 train_time:58812ms step_avg:43.56ms -step:1400/20000 train_loss:2.4312 train_time:60972ms step_avg:43.55ms -step:1400/20000 val_loss:2.2816 val_bpb:1.3513 train_time:60998ms step_avg:43.57ms -step:1450/20000 train_loss:2.2370 train_time:63131ms step_avg:43.54ms -step:1500/20000 train_loss:2.2243 train_time:65290ms step_avg:43.53ms -step:1550/20000 train_loss:2.1575 train_time:67521ms step_avg:43.56ms -step:1600/20000 train_loss:2.0949 train_time:69682ms step_avg:43.55ms -step:1600/20000 val_loss:2.2658 val_bpb:1.3419 train_time:69708ms step_avg:43.57ms -step:1650/20000 train_loss:2.2284 train_time:71842ms step_avg:43.54ms -step:1700/20000 train_loss:2.1699 train_time:74001ms step_avg:43.53ms -step:1750/20000 train_loss:2.2480 train_time:76231ms step_avg:43.56ms -step:1800/20000 train_loss:2.1953 train_time:78389ms step_avg:43.55ms -step:1800/20000 val_loss:2.2498 val_bpb:1.3325 train_time:78415ms step_avg:43.56ms -step:1850/20000 train_loss:2.3052 train_time:80550ms step_avg:43.54ms -step:1900/20000 train_loss:2.1887 train_time:82708ms step_avg:43.53ms -step:1950/20000 train_loss:2.2155 train_time:84939ms step_avg:43.56ms -step:2000/20000 train_loss:2.2511 train_time:87096ms step_avg:43.55ms -step:2000/20000 val_loss:2.2344 val_bpb:1.3233 train_time:87122ms step_avg:43.56ms -step:2050/20000 train_loss:2.2504 train_time:89257ms step_avg:43.54ms -step:2100/20000 train_loss:2.2657 train_time:91488ms step_avg:43.57ms -step:2150/20000 train_loss:2.1876 train_time:93646ms step_avg:43.56ms -step:2200/20000 train_loss:2.0737 train_time:95804ms step_avg:43.55ms -step:2200/20000 val_loss:2.2258 val_bpb:1.3183 train_time:95830ms step_avg:43.56ms -step:2250/20000 train_loss:2.1606 train_time:97964ms step_avg:43.54ms -step:2300/20000 train_loss:2.3783 train_time:100183ms step_avg:43.56ms -step:2350/20000 train_loss:2.2001 train_time:102342ms step_avg:43.55ms -step:2400/20000 train_loss:2.1972 train_time:104501ms step_avg:43.54ms -step:2400/20000 val_loss:2.2161 val_bpb:1.3125 train_time:104527ms step_avg:43.55ms -step:2450/20000 train_loss:2.2019 train_time:106660ms step_avg:43.53ms -step:2500/20000 train_loss:2.1207 train_time:108883ms step_avg:43.55ms -step:2550/20000 train_loss:2.1356 train_time:111041ms step_avg:43.55ms -step:2600/20000 train_loss:2.4109 train_time:113201ms step_avg:43.54ms -step:2600/20000 val_loss:2.2156 val_bpb:1.3122 train_time:113227ms step_avg:43.55ms -step:2650/20000 train_loss:2.2421 train_time:115360ms step_avg:43.53ms -step:2700/20000 train_loss:2.1534 train_time:117590ms step_avg:43.55ms -step:2750/20000 train_loss:2.3597 train_time:119746ms step_avg:43.54ms -step:2800/20000 train_loss:2.2355 train_time:121903ms step_avg:43.54ms -step:2800/20000 val_loss:2.2006 val_bpb:1.3033 train_time:121929ms step_avg:43.55ms -step:2850/20000 train_loss:2.1814 train_time:124063ms step_avg:43.53ms -step:2900/20000 train_loss:2.1753 train_time:126294ms step_avg:43.55ms -step:2950/20000 train_loss:2.2370 train_time:128451ms step_avg:43.54ms -step:3000/20000 train_loss:2.2270 train_time:130608ms step_avg:43.54ms -step:3000/20000 val_loss:2.1934 val_bpb:1.2991 train_time:130634ms step_avg:43.54ms -step:3050/20000 train_loss:2.1673 train_time:132766ms step_avg:43.53ms -step:3100/20000 train_loss:2.2087 train_time:134991ms step_avg:43.55ms -step:3150/20000 train_loss:2.1582 train_time:137147ms step_avg:43.54ms -step:3200/20000 train_loss:2.1888 train_time:139306ms step_avg:43.53ms -step:3200/20000 val_loss:2.1883 val_bpb:1.2960 train_time:139332ms step_avg:43.54ms -step:3250/20000 train_loss:2.0875 train_time:141544ms step_avg:43.55ms -step:3300/20000 train_loss:2.2393 train_time:143699ms step_avg:43.55ms -step:3350/20000 train_loss:2.0944 train_time:145856ms step_avg:43.54ms -step:3400/20000 train_loss:2.1566 train_time:148015ms step_avg:43.53ms -step:3400/20000 val_loss:2.1855 val_bpb:1.2944 train_time:148041ms step_avg:43.54ms -step:3450/20000 train_loss:2.1014 train_time:150256ms step_avg:43.55ms -step:3500/20000 train_loss:2.2469 train_time:152413ms step_avg:43.55ms -step:3550/20000 train_loss:2.3871 train_time:154570ms step_avg:43.54ms -step:3600/20000 train_loss:2.1130 train_time:156727ms step_avg:43.54ms -step:3600/20000 val_loss:2.1775 val_bpb:1.2896 train_time:156753ms step_avg:43.54ms -step:3650/20000 train_loss:2.2147 train_time:158959ms step_avg:43.55ms -step:3700/20000 train_loss:2.1527 train_time:161116ms step_avg:43.54ms -step:3750/20000 train_loss:2.1433 train_time:163273ms step_avg:43.54ms -step:3800/20000 train_loss:2.2194 train_time:165431ms step_avg:43.53ms -step:3800/20000 val_loss:2.1739 val_bpb:1.2875 train_time:165457ms step_avg:43.54ms -step:3850/20000 train_loss:2.1765 train_time:167667ms step_avg:43.55ms -step:3900/20000 train_loss:1.9904 train_time:169824ms step_avg:43.54ms -step:3950/20000 train_loss:2.1268 train_time:171979ms step_avg:43.54ms -step:4000/20000 train_loss:2.1575 train_time:174136ms step_avg:43.53ms -step:4000/20000 val_loss:2.1687 val_bpb:1.2844 train_time:174162ms step_avg:43.54ms -step:4050/20000 train_loss:2.0994 train_time:176357ms step_avg:43.54ms -step:4100/20000 train_loss:2.1890 train_time:178515ms step_avg:43.54ms -step:4150/20000 train_loss:2.3220 train_time:180673ms step_avg:43.54ms -step:4200/20000 train_loss:2.1723 train_time:182905ms step_avg:43.55ms -step:4200/20000 val_loss:2.1653 val_bpb:1.2824 train_time:182931ms step_avg:43.56ms -step:4250/20000 train_loss:2.1258 train_time:185064ms step_avg:43.54ms -step:4300/20000 train_loss:2.0270 train_time:187221ms step_avg:43.54ms -step:4350/20000 train_loss:2.2115 train_time:189378ms step_avg:43.54ms -step:4400/20000 train_loss:2.1117 train_time:191602ms step_avg:43.55ms -step:4400/20000 val_loss:2.1650 val_bpb:1.2822 train_time:191628ms step_avg:43.55ms -step:4450/20000 train_loss:2.0639 train_time:193761ms step_avg:43.54ms -step:4500/20000 train_loss:2.2567 train_time:195918ms step_avg:43.54ms -step:4550/20000 train_loss:2.0546 train_time:198075ms step_avg:43.53ms -step:4600/20000 train_loss:1.9738 train_time:200304ms step_avg:43.54ms -step:4600/20000 val_loss:2.1614 val_bpb:1.2801 train_time:200330ms step_avg:43.55ms -step:4650/20000 train_loss:2.0739 train_time:202462ms step_avg:43.54ms -step:4700/20000 train_loss:2.2671 train_time:204617ms step_avg:43.54ms -step:4750/20000 train_loss:1.9781 train_time:206775ms step_avg:43.53ms -step:4800/20000 train_loss:2.1288 train_time:208998ms step_avg:43.54ms -step:4800/20000 val_loss:2.1560 val_bpb:1.2769 train_time:209024ms step_avg:43.55ms -step:4850/20000 train_loss:2.2142 train_time:211156ms step_avg:43.54ms -step:4900/20000 train_loss:2.4058 train_time:213313ms step_avg:43.53ms -step:4950/20000 train_loss:2.1662 train_time:215470ms step_avg:43.53ms -step:5000/20000 train_loss:2.1369 train_time:217702ms step_avg:43.54ms -step:5000/20000 val_loss:2.1531 val_bpb:1.2752 train_time:217728ms step_avg:43.55ms -step:5050/20000 train_loss:2.0835 train_time:219861ms step_avg:43.54ms -step:5100/20000 train_loss:2.0890 train_time:222016ms step_avg:43.53ms -step:5150/20000 train_loss:2.1470 train_time:224248ms step_avg:43.54ms -step:5200/20000 train_loss:2.2346 train_time:226403ms step_avg:43.54ms -step:5200/20000 val_loss:2.1499 val_bpb:1.2733 train_time:226428ms step_avg:43.54ms -step:5250/20000 train_loss:2.0871 train_time:228560ms step_avg:43.54ms -step:5300/20000 train_loss:2.2136 train_time:230718ms step_avg:43.53ms -step:5350/20000 train_loss:2.5593 train_time:232961ms step_avg:43.54ms -step:5400/20000 train_loss:2.2792 train_time:235118ms step_avg:43.54ms -step:5400/20000 val_loss:2.1485 val_bpb:1.2725 train_time:235144ms step_avg:43.55ms -step:5450/20000 train_loss:2.1606 train_time:237275ms step_avg:43.54ms -step:5500/20000 train_loss:2.1717 train_time:239431ms step_avg:43.53ms -step:5550/20000 train_loss:2.1798 train_time:241666ms step_avg:43.54ms -step:5600/20000 train_loss:2.1629 train_time:243821ms step_avg:43.54ms -step:5600/20000 val_loss:2.1447 val_bpb:1.2702 train_time:243847ms step_avg:43.54ms -step:5650/20000 train_loss:2.1309 train_time:245978ms step_avg:43.54ms -step:5700/20000 train_loss:2.2591 train_time:248134ms step_avg:43.53ms -step:5750/20000 train_loss:2.0920 train_time:250365ms step_avg:43.54ms -step:5800/20000 train_loss:2.2400 train_time:252523ms step_avg:43.54ms -step:5800/20000 val_loss:2.1435 val_bpb:1.2695 train_time:252549ms step_avg:43.54ms -step:5850/20000 train_loss:2.3056 train_time:254680ms step_avg:43.54ms -step:5900/20000 train_loss:2.1449 train_time:256836ms step_avg:43.53ms -step:5950/20000 train_loss:2.0361 train_time:259060ms step_avg:43.54ms -step:6000/20000 train_loss:2.2117 train_time:261217ms step_avg:43.54ms -step:6000/20000 val_loss:2.1396 val_bpb:1.2672 train_time:261243ms step_avg:43.54ms -step:6050/20000 train_loss:2.0154 train_time:263375ms step_avg:43.53ms -step:6100/20000 train_loss:2.2909 train_time:265531ms step_avg:43.53ms -step:6150/20000 train_loss:1.9638 train_time:267768ms step_avg:43.54ms -step:6200/20000 train_loss:2.1064 train_time:269926ms step_avg:43.54ms -step:6200/20000 val_loss:2.1382 val_bpb:1.2663 train_time:269952ms step_avg:43.54ms -step:6250/20000 train_loss:2.1333 train_time:272084ms step_avg:43.53ms -step:6300/20000 train_loss:1.9430 train_time:274323ms step_avg:43.54ms -step:6350/20000 train_loss:2.1705 train_time:276480ms step_avg:43.54ms -step:6400/20000 train_loss:2.1106 train_time:278638ms step_avg:43.54ms -step:6400/20000 val_loss:2.1387 val_bpb:1.2666 train_time:278664ms step_avg:43.54ms -step:6450/20000 train_loss:2.1195 train_time:280798ms step_avg:43.53ms -step:6500/20000 train_loss:2.1114 train_time:283031ms step_avg:43.54ms -step:6550/20000 train_loss:2.0946 train_time:285189ms step_avg:43.54ms -step:6600/20000 train_loss:2.0078 train_time:287345ms step_avg:43.54ms -step:6600/20000 val_loss:2.1353 val_bpb:1.2646 train_time:287371ms step_avg:43.54ms -step:6650/20000 train_loss:2.2139 train_time:289502ms step_avg:43.53ms -step:6700/20000 train_loss:2.1424 train_time:291725ms step_avg:43.54ms -step:6750/20000 train_loss:2.1548 train_time:293881ms step_avg:43.54ms -step:6800/20000 train_loss:1.9517 train_time:296037ms step_avg:43.53ms -step:6800/20000 val_loss:2.1357 val_bpb:1.2649 train_time:296063ms step_avg:43.54ms -step:6850/20000 train_loss:2.0691 train_time:298196ms step_avg:43.53ms -step:6900/20000 train_loss:2.1401 train_time:300417ms step_avg:43.54ms -step:6950/20000 train_loss:2.0302 train_time:302574ms step_avg:43.54ms -step:7000/20000 train_loss:2.1942 train_time:304732ms step_avg:43.53ms -step:7000/20000 val_loss:2.1313 val_bpb:1.2623 train_time:304757ms step_avg:43.54ms -step:7050/20000 train_loss:2.0640 train_time:306888ms step_avg:43.53ms -step:7100/20000 train_loss:2.2295 train_time:309119ms step_avg:43.54ms -step:7150/20000 train_loss:2.1169 train_time:311275ms step_avg:43.53ms -step:7200/20000 train_loss:2.0339 train_time:313431ms step_avg:43.53ms -step:7200/20000 val_loss:2.1306 val_bpb:1.2618 train_time:313457ms step_avg:43.54ms -step:7250/20000 train_loss:2.0819 train_time:315666ms step_avg:43.54ms -step:7300/20000 train_loss:2.1837 train_time:317825ms step_avg:43.54ms -step:7350/20000 train_loss:2.2092 train_time:319983ms step_avg:43.54ms -step:7400/20000 train_loss:2.1393 train_time:322139ms step_avg:43.53ms -step:7400/20000 val_loss:2.1278 val_bpb:1.2602 train_time:322165ms step_avg:43.54ms -step:7450/20000 train_loss:2.1666 train_time:324375ms step_avg:43.54ms -step:7500/20000 train_loss:2.1289 train_time:326531ms step_avg:43.54ms -step:7550/20000 train_loss:2.1410 train_time:328688ms step_avg:43.53ms -step:7600/20000 train_loss:2.1547 train_time:330844ms step_avg:43.53ms -step:7600/20000 val_loss:2.1269 val_bpb:1.2597 train_time:330870ms step_avg:43.54ms -step:7650/20000 train_loss:2.1230 train_time:333073ms step_avg:43.54ms -step:7700/20000 train_loss:2.1483 train_time:335230ms step_avg:43.54ms -step:7750/20000 train_loss:2.2352 train_time:337389ms step_avg:43.53ms -step:7800/20000 train_loss:2.0825 train_time:339547ms step_avg:43.53ms -step:7800/20000 val_loss:2.1266 val_bpb:1.2595 train_time:339573ms step_avg:43.53ms -step:7850/20000 train_loss:2.1335 train_time:341782ms step_avg:43.54ms -step:7900/20000 train_loss:2.0890 train_time:343938ms step_avg:43.54ms -step:7950/20000 train_loss:2.1352 train_time:346095ms step_avg:43.53ms -step:8000/20000 train_loss:2.1520 train_time:348252ms step_avg:43.53ms -step:8000/20000 val_loss:2.1241 val_bpb:1.2580 train_time:348278ms step_avg:43.53ms -step:8050/20000 train_loss:2.1567 train_time:350488ms step_avg:43.54ms -step:8100/20000 train_loss:2.1795 train_time:352645ms step_avg:43.54ms -step:8150/20000 train_loss:2.0547 train_time:354801ms step_avg:43.53ms -step:8200/20000 train_loss:2.0225 train_time:356958ms step_avg:43.53ms -step:8200/20000 val_loss:2.1258 val_bpb:1.2590 train_time:356984ms step_avg:43.53ms -step:8250/20000 train_loss:2.0966 train_time:359184ms step_avg:43.54ms -step:8300/20000 train_loss:2.0529 train_time:361342ms step_avg:43.54ms -step:8350/20000 train_loss:2.1786 train_time:363500ms step_avg:43.53ms -step:8400/20000 train_loss:2.2070 train_time:365727ms step_avg:43.54ms -step:8400/20000 val_loss:2.1209 val_bpb:1.2561 train_time:365753ms step_avg:43.54ms -step:8450/20000 train_loss:2.1895 train_time:367888ms step_avg:43.54ms -step:8500/20000 train_loss:2.1274 train_time:370046ms step_avg:43.53ms -step:8550/20000 train_loss:2.1892 train_time:372204ms step_avg:43.53ms -step:8600/20000 train_loss:2.1199 train_time:374439ms step_avg:43.54ms -step:8600/20000 val_loss:2.1178 val_bpb:1.2543 train_time:374465ms step_avg:43.54ms -step:8650/20000 train_loss:2.0115 train_time:376598ms step_avg:43.54ms -step:8700/20000 train_loss:2.0750 train_time:378755ms step_avg:43.54ms -step:8750/20000 train_loss:2.1178 train_time:380911ms step_avg:43.53ms -step:8800/20000 train_loss:2.0584 train_time:383155ms step_avg:43.54ms -step:8800/20000 val_loss:2.1169 val_bpb:1.2537 train_time:383181ms step_avg:43.54ms -step:8850/20000 train_loss:2.0574 train_time:385315ms step_avg:43.54ms -step:8900/20000 train_loss:2.1138 train_time:387474ms step_avg:43.54ms -step:8950/20000 train_loss:2.1643 train_time:389633ms step_avg:43.53ms -step:9000/20000 train_loss:2.3215 train_time:391862ms step_avg:43.54ms -step:9000/20000 val_loss:2.1158 val_bpb:1.2531 train_time:391888ms step_avg:43.54ms -step:9050/20000 train_loss:2.1835 train_time:394019ms step_avg:43.54ms -step:9100/20000 train_loss:2.0095 train_time:396176ms step_avg:43.54ms -step:9150/20000 train_loss:2.2263 train_time:398334ms step_avg:43.53ms -step:9200/20000 train_loss:2.2858 train_time:400574ms step_avg:43.54ms -step:9200/20000 val_loss:2.1145 val_bpb:1.2523 train_time:400600ms step_avg:43.54ms -step:9250/20000 train_loss:2.1471 train_time:402731ms step_avg:43.54ms -step:9300/20000 train_loss:2.3688 train_time:404889ms step_avg:43.54ms -step:9350/20000 train_loss:2.1754 train_time:407124ms step_avg:43.54ms -step:9400/20000 train_loss:1.9229 train_time:409282ms step_avg:43.54ms -step:9400/20000 val_loss:2.1153 val_bpb:1.2528 train_time:409308ms step_avg:43.54ms -step:9450/20000 train_loss:2.0598 train_time:411439ms step_avg:43.54ms -step:9500/20000 train_loss:2.1587 train_time:413598ms step_avg:43.54ms -step:9550/20000 train_loss:2.2159 train_time:415826ms step_avg:43.54ms -step:9600/20000 train_loss:2.0161 train_time:417982ms step_avg:43.54ms -step:9600/20000 val_loss:2.1140 val_bpb:1.2521 train_time:418008ms step_avg:43.54ms -step:9650/20000 train_loss:2.0622 train_time:420141ms step_avg:43.54ms -step:9700/20000 train_loss:2.1559 train_time:422299ms step_avg:43.54ms -step:9750/20000 train_loss:2.1587 train_time:424521ms step_avg:43.54ms -step:9800/20000 train_loss:2.0718 train_time:426677ms step_avg:43.54ms -step:9800/20000 val_loss:2.1101 val_bpb:1.2497 train_time:426703ms step_avg:43.54ms -step:9850/20000 train_loss:2.1370 train_time:428834ms step_avg:43.54ms -step:9900/20000 train_loss:2.0179 train_time:430991ms step_avg:43.53ms -step:9950/20000 train_loss:2.1451 train_time:433216ms step_avg:43.54ms -step:10000/20000 train_loss:2.0310 train_time:435372ms step_avg:43.54ms -step:10000/20000 val_loss:2.1125 val_bpb:1.2511 train_time:435398ms step_avg:43.54ms -step:10050/20000 train_loss:2.0433 train_time:437530ms step_avg:43.54ms -step:10100/20000 train_loss:2.1036 train_time:439688ms step_avg:43.53ms -step:10150/20000 train_loss:2.1501 train_time:441912ms step_avg:43.54ms -step:10200/20000 train_loss:2.1359 train_time:444070ms step_avg:43.54ms -step:10200/20000 val_loss:2.1103 val_bpb:1.2498 train_time:444096ms step_avg:43.54ms -step:10250/20000 train_loss:2.0957 train_time:446230ms step_avg:43.53ms -step:10300/20000 train_loss:1.9998 train_time:448466ms step_avg:43.54ms -step:10350/20000 train_loss:1.9660 train_time:450623ms step_avg:43.54ms -step:10400/20000 train_loss:2.0992 train_time:452783ms step_avg:43.54ms -step:10400/20000 val_loss:2.1087 val_bpb:1.2489 train_time:452809ms step_avg:43.54ms -step:10450/20000 train_loss:1.9782 train_time:454944ms step_avg:43.54ms -step:10500/20000 train_loss:2.0581 train_time:457170ms step_avg:43.54ms -step:10550/20000 train_loss:2.1275 train_time:459329ms step_avg:43.54ms -step:10600/20000 train_loss:2.0735 train_time:461486ms step_avg:43.54ms -step:10600/20000 val_loss:2.1080 val_bpb:1.2485 train_time:461512ms step_avg:43.54ms -step:10650/20000 train_loss:2.1363 train_time:463644ms step_avg:43.53ms -step:10700/20000 train_loss:2.2608 train_time:465878ms step_avg:43.54ms -step:10750/20000 train_loss:1.9986 train_time:468035ms step_avg:43.54ms -step:10800/20000 train_loss:2.1241 train_time:470193ms step_avg:43.54ms -step:10800/20000 val_loss:2.1073 val_bpb:1.2481 train_time:470219ms step_avg:43.54ms -step:10850/20000 train_loss:2.2360 train_time:472353ms step_avg:43.53ms -step:10900/20000 train_loss:2.1901 train_time:474587ms step_avg:43.54ms -step:10950/20000 train_loss:2.0288 train_time:476744ms step_avg:43.54ms -step:11000/20000 train_loss:2.0966 train_time:478902ms step_avg:43.54ms -step:11000/20000 val_loss:2.1073 val_bpb:1.2481 train_time:478928ms step_avg:43.54ms -step:11050/20000 train_loss:2.1250 train_time:481060ms step_avg:43.53ms -step:11100/20000 train_loss:2.1007 train_time:483294ms step_avg:43.54ms -step:11150/20000 train_loss:2.0697 train_time:485453ms step_avg:43.54ms -step:11200/20000 train_loss:2.1355 train_time:487609ms step_avg:43.54ms -step:11200/20000 val_loss:2.1058 val_bpb:1.2472 train_time:487635ms step_avg:43.54ms -step:11250/20000 train_loss:2.1169 train_time:489769ms step_avg:43.54ms -step:11300/20000 train_loss:2.0824 train_time:491992ms step_avg:43.54ms -step:11350/20000 train_loss:2.0735 train_time:494151ms step_avg:43.54ms -step:11400/20000 train_loss:2.2207 train_time:496307ms step_avg:43.54ms -step:11400/20000 val_loss:2.1039 val_bpb:1.2461 train_time:496333ms step_avg:43.54ms -step:11450/20000 train_loss:2.1193 train_time:498546ms step_avg:43.54ms -step:11500/20000 train_loss:2.1019 train_time:500701ms step_avg:43.54ms -step:11550/20000 train_loss:2.1322 train_time:502859ms step_avg:43.54ms -step:11600/20000 train_loss:2.1440 train_time:505016ms step_avg:43.54ms -step:11600/20000 val_loss:2.1038 val_bpb:1.2460 train_time:505042ms step_avg:43.54ms -step:11650/20000 train_loss:2.1517 train_time:507249ms step_avg:43.54ms -step:11700/20000 train_loss:2.0176 train_time:509406ms step_avg:43.54ms -step:11750/20000 train_loss:2.0675 train_time:511563ms step_avg:43.54ms -step:11800/20000 train_loss:2.0198 train_time:513722ms step_avg:43.54ms -step:11800/20000 val_loss:2.1048 val_bpb:1.2466 train_time:513748ms step_avg:43.54ms -step:11850/20000 train_loss:2.1150 train_time:515951ms step_avg:43.54ms -step:11900/20000 train_loss:2.2966 train_time:518107ms step_avg:43.54ms -step:11950/20000 train_loss:2.1163 train_time:520262ms step_avg:43.54ms -step:12000/20000 train_loss:2.0682 train_time:522418ms step_avg:43.53ms -step:12000/20000 val_loss:2.1019 val_bpb:1.2449 train_time:522444ms step_avg:43.54ms -step:12050/20000 train_loss:2.1960 train_time:524644ms step_avg:43.54ms -step:12100/20000 train_loss:2.0906 train_time:526802ms step_avg:43.54ms -step:12150/20000 train_loss:2.1258 train_time:528958ms step_avg:43.54ms -step:12200/20000 train_loss:2.1618 train_time:531116ms step_avg:43.53ms -step:12200/20000 val_loss:2.1011 val_bpb:1.2444 train_time:531142ms step_avg:43.54ms -step:12250/20000 train_loss:2.0895 train_time:533353ms step_avg:43.54ms -step:12300/20000 train_loss:1.9704 train_time:535510ms step_avg:43.54ms -step:12350/20000 train_loss:2.0886 train_time:537667ms step_avg:43.54ms -step:12400/20000 train_loss:2.1513 train_time:539892ms step_avg:43.54ms -step:12400/20000 val_loss:2.1007 val_bpb:1.2441 train_time:539918ms step_avg:43.54ms -step:12450/20000 train_loss:2.1455 train_time:542050ms step_avg:43.54ms -step:12500/20000 train_loss:2.1533 train_time:544208ms step_avg:43.54ms -step:12550/20000 train_loss:2.1017 train_time:546366ms step_avg:43.54ms -step:12600/20000 train_loss:2.3639 train_time:548587ms step_avg:43.54ms -step:12600/20000 val_loss:2.0991 val_bpb:1.2432 train_time:548613ms step_avg:43.54ms -step:12650/20000 train_loss:1.9838 train_time:550746ms step_avg:43.54ms -step:12700/20000 train_loss:2.0458 train_time:552901ms step_avg:43.54ms -step:12750/20000 train_loss:2.0925 train_time:555058ms step_avg:43.53ms -step:12800/20000 train_loss:2.1697 train_time:557290ms step_avg:43.54ms -step:12800/20000 val_loss:2.0941 val_bpb:1.2402 train_time:557316ms step_avg:43.54ms -step:12850/20000 train_loss:1.9963 train_time:559448ms step_avg:43.54ms -step:12900/20000 train_loss:2.4743 train_time:561604ms step_avg:43.54ms -step:12950/20000 train_loss:2.0349 train_time:563761ms step_avg:43.53ms -step:13000/20000 train_loss:2.1242 train_time:565989ms step_avg:43.54ms -step:13000/20000 val_loss:2.0860 val_bpb:1.2354 train_time:566015ms step_avg:43.54ms -step:13050/20000 train_loss:2.0320 train_time:568146ms step_avg:43.54ms -step:13100/20000 train_loss:1.9702 train_time:570303ms step_avg:43.53ms -step:13150/20000 train_loss:2.1049 train_time:572460ms step_avg:43.53ms -step:13200/20000 train_loss:2.1302 train_time:574687ms step_avg:43.54ms -step:13200/20000 val_loss:2.0785 val_bpb:1.2310 train_time:574713ms step_avg:43.54ms -step:13250/20000 train_loss:2.1164 train_time:576845ms step_avg:43.54ms -step:13300/20000 train_loss:2.1376 train_time:579003ms step_avg:43.53ms -step:13350/20000 train_loss:2.1590 train_time:581160ms step_avg:43.53ms -step:13400/20000 train_loss:2.1482 train_time:583404ms step_avg:43.54ms -step:13400/20000 val_loss:2.0713 val_bpb:1.2268 train_time:583430ms step_avg:43.54ms -step:13450/20000 train_loss:2.1370 train_time:585562ms step_avg:43.54ms -step:13500/20000 train_loss:1.9860 train_time:587719ms step_avg:43.53ms -step:13550/20000 train_loss:2.0820 train_time:589942ms step_avg:43.54ms -step:13600/20000 train_loss:2.0471 train_time:592097ms step_avg:43.54ms -step:13600/20000 val_loss:2.0638 val_bpb:1.2223 train_time:592123ms step_avg:43.54ms -step:13650/20000 train_loss:2.0668 train_time:594256ms step_avg:43.54ms -step:13700/20000 train_loss:2.0994 train_time:596413ms step_avg:43.53ms -step:13750/20000 train_loss:2.0810 train_time:598642ms step_avg:43.54ms -step:13782/20000 val_loss:2.0590 val_bpb:1.2194 train_time:600047ms step_avg:43.54ms -stopping_early: wallclock_cap train_time:600047ms step:13782/20000 -peak memory allocated: 10184 MiB reserved: 10246 MiB -Serialized model: 67224983 bytes -Code size: 58465 bytes -Total submission size: 67283448 bytes -Serialized model int8+zlib: 15807986 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15866451 bytes -final_int8_zlib_roundtrip val_loss:2.0716 val_bpb:1.2269 eval_time:1381ms -final_int8_zlib_roundtrip_exact val_loss:2.07160602 val_bpb:1.22692177 -final_int8_ttt_lora val_loss:2.0128 val_bpb:1.1921 eval_time:59874ms diff --git a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v3.txt b/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v3.txt deleted file mode 100644 index 7289e3106c..0000000000 --- a/records/track_10min_16mb/2026-03-17_LoRA_TTT/train_v3.txt +++ /dev/null @@ -1,1828 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - resume_from = os.environ.get("RESUME_FROM", "") - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # Test-time training (LoRA) hyperparameters. - ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) - ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) - ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) - ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x) + (q_delta if q_delta is not None else 0) - k = self.c_k(x) - v = self.c_v(x) + (v_delta if v_delta is not None else 0) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - qd = lora.q_loras[i] if lora else None - vd = lora.v_loras[i] if lora else None - x = self.blocks[i](x, x0, qd, vd) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - qd = lora.q_loras[bi] if lora else None - vd = lora.v_loras[bi] if lora else None - x = self.blocks[bi](x, x0, qd, vd) - x = self.final_norm(x) - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) - else: - logits = self.lm_head(x) - logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - if lora: - bsz, sl, V = logits.shape - return F.cross_entropy( - logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") - - -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. - -BOS_ID = 1 - -class BatchedLinearLoRA(nn.Module): - """LoRA for a linear layer, with independent weights per batch element. - Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" - def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): - super().__init__() - self.in_features = in_features - self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection - self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection - self.reset() - - def forward(self, x: Tensor) -> Tensor: - return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) - - def reset(self) -> None: - bound = 1.0 / math.sqrt(self.in_features) - with torch.no_grad(): - self.A.uniform_(-bound, bound) # kaiming-uniform - self.B.zero_() - -class BatchedTTTLoRA(nn.Module): - """All LoRA adapters for one batch: LM head and Q/V per block.""" - def __init__(self, bsz: int, model: GPT, rank: int): - super().__init__() - dim = model.tok_emb.embedding_dim - vocab = model.tok_emb.num_embeddings - self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) - self.q_loras = nn.ModuleList() - self.v_loras = nn.ModuleList() - for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) - - def reset(self) -> None: - for m in self.modules(): - if isinstance(m, BatchedLinearLoRA): - m.reset() - -def _reset_ttt_optimizer(opt): - for group in opt.param_groups: - for p in group['params']: - s = opt.state.get(p) - if not s: # Fresh state. - continue - s['exp_avg'].zero_() - s['exp_avg_sq'].zero_() - s['step'].fill_(0) - -def _build_ttt_optimizer(lora, args: Hyperparameters): - return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) - -def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: - """Return (start_offset, length) for each document, identified by BOS boundaries. - - If include_next_bos is True, include next document's BOS (to match continuous-stream - eval token count exactly). - """ - bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() - docs = [] - for i in range(len(bos_positions)): - start = int(bos_positions[i]) - end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() - if include_next_bos and i + 1 < len(bos_positions): - end += 1 - assert end - start >= 2 - docs.append((start, end - start)) - return docs - -def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): - """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" - chunk_start = ci * chunk_size - chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size - win_start = max(0, chunk_end - eval_seq_len) - win_len = chunk_end - win_start - chunk_offset = chunk_start - win_start - chunk_len = chunk_end - chunk_start - return win_start, win_len, chunk_offset, chunk_len - -def _accumulate_bpb( - ptl: Tensor, x: Tensor, y: Tensor, - batch_i: int, chunk_offset: int, chunk_len: int, - base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, -): - """Add one doc-chunk's contribution to the running BPB accumulators.""" - lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) - prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] - tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] - tok_bytes = base_bytes_lut[tgt].to(torch.float64) - tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] - loss_sum += lbl.sum() - byte_sum += tok_bytes.sum() - token_count += chunk_len - -def eval_val_ttt_lora( - args: Hyperparameters, - base_model: GPT, - rank: int, - world_size: int, - device: torch.device, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries - files = sorted(glob.glob(args.val_files)) - all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) - docs = _find_docs(all_tokens) - - # Each rank takes a contiguous slice of documents - rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] - chunk_size = args.ttt_chunk_size - eval_seq_len = args.ttt_eval_seq_len - batch_size = args.ttt_batch_size - lora_rank = args.ttt_lora_rank - - rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) - - base_model.eval() - for p in base_model.parameters(): - p.requires_grad_(False) - - lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) - opt = _build_ttt_optimizer(lora, args) - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - byte_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - - for bi in range(0, len(rank_docs), batch_size): - batch = rank_docs[bi:bi + batch_size] - bsz = len(batch) - - if bsz == batch_size: - cur_lora, cur_opt = lora, opt - cur_lora.reset() - _reset_ttt_optimizer(cur_opt) - else: - cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) - cur_opt = _build_ttt_optimizer(cur_lora, args) - - pred_lens = [doc_len - 1 for _, doc_len in batch] - num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] - max_nc = max(num_chunks) - - for ci in range(max_nc): - chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) - context_size, chunk_offset = chunk_stats[1], chunk_stats[2] - - active = [ci < nc for nc in num_chunks] - needs_train = any(ci < nc - 1 for nc in num_chunks) - - x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - doc_info = [] # (chunk_offset, chunk_len) per doc - for b in range(bsz): - if not active[b]: - doc_info.append((0, 0)) - continue - ds, dl = batch[b] - ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) - chunk = all_tokens[ds + ws: ds + ws + wl + 1] - toks = chunk.to(dtype=torch.int64, device=device) - x[b, :wl] = toks[:-1] - y[b, :wl] = toks[1:] - doc_info.append((co, cl)) - - # Forward pass (keep grad graph alive only when we need to train) - if needs_train: - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - else: - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - - # Score: accumulate loss and byte counts for BPB (before training on chunk) - with torch.no_grad(): - for b in range(bsz): - if not active[b]: - continue - co, cl = doc_info[b] - _accumulate_bpb( - ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, loss_sum, byte_sum, token_count) - - # Train: one Adam step on the LoRA params using this chunk's loss - if needs_train: - mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) - per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) - cur_opt.zero_grad() - (per_doc * mask).sum().backward() - cur_opt.step() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - - val_loss = float(loss_sum.item() / token_count.item()) - val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - if isinstance(module, Rotary): - module.inv_freq.data = module.inv_freq.data.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # LoRA test-time training evaluation (the competition score) - torch._dynamo.reset() - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( - args, base_model, rank, world_size, device, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.13 (main, Mar 10 2026, 18:17:25) [Clang 21.1.4 ] -Running PyTorch 2.10.0+cu128 -Thu Mar 19 11:15:42 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | -| N/A 41C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | -| N/A 35C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 40C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | -| N/A 35C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | -| N/A 41C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | -| N/A 36C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | -| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | -| N/A 34C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 55839 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 1 N/A N/A 55840 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 2 N/A N/A 55841 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 3 N/A N/A 55842 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 4 N/A N/A 55843 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 5 N/A N/A 55844 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 6 N/A N/A 55845 C ...ai-codegolf/.venv/bin/python3 1510MiB | -| 7 N/A N/A 55846 C ...ai-codegolf/.venv/bin/python3 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:25 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9370 train_time:24ms step_avg:24.04ms -step:2/20000 train_loss:16.8367 train_time:66ms step_avg:32.99ms -step:3/20000 train_loss:8.7608 train_time:110ms step_avg:36.59ms -step:4/20000 train_loss:6.6385 train_time:153ms step_avg:38.28ms -step:5/20000 train_loss:6.6121 train_time:197ms step_avg:39.34ms -step:6/20000 train_loss:7.4220 train_time:241ms step_avg:40.08ms -step:7/20000 train_loss:6.3502 train_time:284ms step_avg:40.58ms -step:8/20000 train_loss:6.1582 train_time:328ms step_avg:40.98ms -step:9/20000 train_loss:6.0679 train_time:371ms step_avg:41.26ms -step:10/20000 train_loss:5.9745 train_time:415ms step_avg:41.52ms -step:50/20000 train_loss:4.0980 train_time:2156ms step_avg:43.12ms -step:100/20000 train_loss:3.4059 train_time:4333ms step_avg:43.33ms -step:150/20000 train_loss:3.0556 train_time:6509ms step_avg:43.39ms -step:200/20000 train_loss:2.8570 train_time:8746ms step_avg:43.73ms -step:200/20000 val_loss:2.8396 val_bpb:1.6817 train_time:8772ms step_avg:43.86ms -step:250/20000 train_loss:2.7420 train_time:10923ms step_avg:43.69ms -step:300/20000 train_loss:2.4765 train_time:13097ms step_avg:43.66ms -step:350/20000 train_loss:2.6724 train_time:15280ms step_avg:43.66ms -step:400/20000 train_loss:2.3554 train_time:17522ms step_avg:43.80ms -step:400/20000 val_loss:2.5660 val_bpb:1.5197 train_time:17548ms step_avg:43.87ms -step:450/20000 train_loss:2.5112 train_time:19694ms step_avg:43.76ms -step:500/20000 train_loss:2.4970 train_time:21865ms step_avg:43.73ms -step:550/20000 train_loss:2.4029 train_time:24036ms step_avg:43.70ms -step:600/20000 train_loss:2.5472 train_time:26269ms step_avg:43.78ms -step:600/20000 val_loss:2.4526 val_bpb:1.4526 train_time:26295ms step_avg:43.83ms -step:650/20000 train_loss:2.3897 train_time:28441ms step_avg:43.76ms -step:700/20000 train_loss:2.4431 train_time:30612ms step_avg:43.73ms -step:750/20000 train_loss:2.2780 train_time:32782ms step_avg:43.71ms -step:800/20000 train_loss:2.2972 train_time:35024ms step_avg:43.78ms -step:800/20000 val_loss:2.3835 val_bpb:1.4116 train_time:35051ms step_avg:43.81ms -step:850/20000 train_loss:2.7247 train_time:37196ms step_avg:43.76ms -step:900/20000 train_loss:2.3416 train_time:39367ms step_avg:43.74ms -step:950/20000 train_loss:2.4075 train_time:41538ms step_avg:43.72ms -step:1000/20000 train_loss:2.3741 train_time:43774ms step_avg:43.77ms -step:1000/20000 val_loss:2.3353 val_bpb:1.3831 train_time:43800ms step_avg:43.80ms -step:1050/20000 train_loss:2.4838 train_time:45945ms step_avg:43.76ms -step:1100/20000 train_loss:2.2593 train_time:48115ms step_avg:43.74ms -step:1150/20000 train_loss:2.2540 train_time:50353ms step_avg:43.79ms -step:1200/20000 train_loss:2.3866 train_time:52523ms step_avg:43.77ms -step:1200/20000 val_loss:2.3042 val_bpb:1.3647 train_time:52549ms step_avg:43.79ms -step:1250/20000 train_loss:2.2102 train_time:54693ms step_avg:43.75ms -step:1300/20000 train_loss:2.3626 train_time:56863ms step_avg:43.74ms -step:1350/20000 train_loss:2.2732 train_time:59102ms step_avg:43.78ms -step:1400/20000 train_loss:2.4311 train_time:61273ms step_avg:43.77ms -step:1400/20000 val_loss:2.2830 val_bpb:1.3521 train_time:61299ms step_avg:43.79ms -step:1450/20000 train_loss:2.2370 train_time:63445ms step_avg:43.76ms -step:1500/20000 train_loss:2.2247 train_time:65619ms step_avg:43.75ms -step:1550/20000 train_loss:2.1566 train_time:67854ms step_avg:43.78ms -step:1600/20000 train_loss:2.0967 train_time:70025ms step_avg:43.77ms -step:1600/20000 val_loss:2.2676 val_bpb:1.3430 train_time:70051ms step_avg:43.78ms -step:1650/20000 train_loss:2.2340 train_time:72198ms step_avg:43.76ms -step:1700/20000 train_loss:2.1750 train_time:74368ms step_avg:43.75ms -step:1750/20000 train_loss:2.2492 train_time:76604ms step_avg:43.77ms -step:1800/20000 train_loss:2.1997 train_time:78778ms step_avg:43.77ms -step:1800/20000 val_loss:2.2522 val_bpb:1.3339 train_time:78804ms step_avg:43.78ms -step:1850/20000 train_loss:2.3072 train_time:80950ms step_avg:43.76ms -step:1900/20000 train_loss:2.1950 train_time:83120ms step_avg:43.75ms -step:1950/20000 train_loss:2.2103 train_time:85359ms step_avg:43.77ms -step:2000/20000 train_loss:2.2531 train_time:87532ms step_avg:43.77ms -step:2000/20000 val_loss:2.2376 val_bpb:1.3252 train_time:87558ms step_avg:43.78ms -step:2050/20000 train_loss:2.2545 train_time:89706ms step_avg:43.76ms -step:2100/20000 train_loss:2.2682 train_time:91951ms step_avg:43.79ms -step:2150/20000 train_loss:2.1910 train_time:94122ms step_avg:43.78ms -step:2200/20000 train_loss:2.0775 train_time:96292ms step_avg:43.77ms -step:2200/20000 val_loss:2.2287 val_bpb:1.3200 train_time:96318ms step_avg:43.78ms -step:2250/20000 train_loss:2.1628 train_time:98463ms step_avg:43.76ms -step:2300/20000 train_loss:2.3838 train_time:100706ms step_avg:43.79ms -step:2350/20000 train_loss:2.1988 train_time:102878ms step_avg:43.78ms -step:2400/20000 train_loss:2.2034 train_time:105048ms step_avg:43.77ms -step:2400/20000 val_loss:2.2178 val_bpb:1.3135 train_time:105074ms step_avg:43.78ms -step:2450/20000 train_loss:2.2071 train_time:107220ms step_avg:43.76ms -step:2500/20000 train_loss:2.1254 train_time:109456ms step_avg:43.78ms -step:2550/20000 train_loss:2.1383 train_time:111628ms step_avg:43.78ms -step:2600/20000 train_loss:2.4129 train_time:113800ms step_avg:43.77ms -step:2600/20000 val_loss:2.2204 val_bpb:1.3150 train_time:113826ms step_avg:43.78ms -step:2650/20000 train_loss:2.2454 train_time:115974ms step_avg:43.76ms -step:2700/20000 train_loss:2.1558 train_time:118222ms step_avg:43.79ms -step:2750/20000 train_loss:2.3613 train_time:120392ms step_avg:43.78ms -step:2800/20000 train_loss:2.2343 train_time:122566ms step_avg:43.77ms -step:2800/20000 val_loss:2.2031 val_bpb:1.3048 train_time:122592ms step_avg:43.78ms -step:2850/20000 train_loss:2.1913 train_time:124738ms step_avg:43.77ms -step:2900/20000 train_loss:2.1802 train_time:127001ms step_avg:43.79ms -step:2950/20000 train_loss:2.2385 train_time:129172ms step_avg:43.79ms -step:3000/20000 train_loss:2.2329 train_time:131345ms step_avg:43.78ms -step:3000/20000 val_loss:2.1966 val_bpb:1.3009 train_time:131372ms step_avg:43.79ms -step:3050/20000 train_loss:2.1716 train_time:133517ms step_avg:43.78ms -step:3100/20000 train_loss:2.2063 train_time:135763ms step_avg:43.79ms -step:3150/20000 train_loss:2.1626 train_time:137938ms step_avg:43.79ms -step:3200/20000 train_loss:2.1921 train_time:140108ms step_avg:43.78ms -step:3200/20000 val_loss:2.1905 val_bpb:1.2973 train_time:140134ms step_avg:43.79ms -step:3250/20000 train_loss:2.0906 train_time:142349ms step_avg:43.80ms -step:3300/20000 train_loss:2.2413 train_time:144521ms step_avg:43.79ms -step:3350/20000 train_loss:2.0963 train_time:146692ms step_avg:43.79ms -step:3400/20000 train_loss:2.1598 train_time:148864ms step_avg:43.78ms -step:3400/20000 val_loss:2.1874 val_bpb:1.2955 train_time:148890ms step_avg:43.79ms -step:3450/20000 train_loss:2.1116 train_time:151109ms step_avg:43.80ms -step:3500/20000 train_loss:2.2545 train_time:153278ms step_avg:43.79ms -step:3550/20000 train_loss:2.3901 train_time:155449ms step_avg:43.79ms -step:3600/20000 train_loss:2.1151 train_time:157620ms step_avg:43.78ms -step:3600/20000 val_loss:2.1798 val_bpb:1.2910 train_time:157646ms step_avg:43.79ms -step:3650/20000 train_loss:2.2218 train_time:159858ms step_avg:43.80ms -step:3700/20000 train_loss:2.1565 train_time:162029ms step_avg:43.79ms -step:3750/20000 train_loss:2.1475 train_time:164200ms step_avg:43.79ms -step:3800/20000 train_loss:2.2253 train_time:166371ms step_avg:43.78ms -step:3800/20000 val_loss:2.1760 val_bpb:1.2888 train_time:166398ms step_avg:43.79ms -step:3850/20000 train_loss:2.1791 train_time:168620ms step_avg:43.80ms -step:3900/20000 train_loss:1.9959 train_time:170789ms step_avg:43.79ms -step:3950/20000 train_loss:2.1253 train_time:172959ms step_avg:43.79ms -step:4000/20000 train_loss:2.1645 train_time:175133ms step_avg:43.78ms -step:4000/20000 val_loss:2.1710 val_bpb:1.2858 train_time:175159ms step_avg:43.79ms -step:4050/20000 train_loss:2.1017 train_time:177373ms step_avg:43.80ms -step:4100/20000 train_loss:2.1904 train_time:179543ms step_avg:43.79ms -step:4150/20000 train_loss:2.3270 train_time:181715ms step_avg:43.79ms -step:4200/20000 train_loss:2.1751 train_time:183961ms step_avg:43.80ms -step:4200/20000 val_loss:2.1674 val_bpb:1.2836 train_time:183987ms step_avg:43.81ms -step:4250/20000 train_loss:2.1313 train_time:186133ms step_avg:43.80ms -step:4300/20000 train_loss:2.0313 train_time:188304ms step_avg:43.79ms -step:4350/20000 train_loss:2.2108 train_time:190473ms step_avg:43.79ms -step:4400/20000 train_loss:2.1132 train_time:192710ms step_avg:43.80ms -step:4400/20000 val_loss:2.1677 val_bpb:1.2839 train_time:192737ms step_avg:43.80ms -step:4450/20000 train_loss:2.0662 train_time:194882ms step_avg:43.79ms -step:4500/20000 train_loss:2.2548 train_time:197054ms step_avg:43.79ms -step:4550/20000 train_loss:2.0614 train_time:199226ms step_avg:43.79ms -step:4600/20000 train_loss:1.9724 train_time:201465ms step_avg:43.80ms -step:4600/20000 val_loss:2.1629 val_bpb:1.2810 train_time:201492ms step_avg:43.80ms -step:4650/20000 train_loss:2.0723 train_time:203638ms step_avg:43.79ms -step:4700/20000 train_loss:2.2682 train_time:205809ms step_avg:43.79ms -step:4750/20000 train_loss:1.9806 train_time:207980ms step_avg:43.79ms -step:4800/20000 train_loss:2.1260 train_time:210219ms step_avg:43.80ms -step:4800/20000 val_loss:2.1585 val_bpb:1.2784 train_time:210245ms step_avg:43.80ms -step:4850/20000 train_loss:2.2142 train_time:212393ms step_avg:43.79ms -step:4900/20000 train_loss:2.4104 train_time:214564ms step_avg:43.79ms -step:4950/20000 train_loss:2.1696 train_time:216737ms step_avg:43.79ms -step:5000/20000 train_loss:2.1399 train_time:218974ms step_avg:43.79ms -step:5000/20000 val_loss:2.1552 val_bpb:1.2765 train_time:219000ms step_avg:43.80ms -step:5050/20000 train_loss:2.0838 train_time:221146ms step_avg:43.79ms -step:5100/20000 train_loss:2.0876 train_time:223315ms step_avg:43.79ms -step:5150/20000 train_loss:2.1514 train_time:225550ms step_avg:43.80ms -step:5200/20000 train_loss:2.2388 train_time:227721ms step_avg:43.79ms -step:5200/20000 val_loss:2.1522 val_bpb:1.2746 train_time:227747ms step_avg:43.80ms -step:5250/20000 train_loss:2.0890 train_time:229893ms step_avg:43.79ms -step:5300/20000 train_loss:2.2180 train_time:232064ms step_avg:43.79ms -step:5350/20000 train_loss:2.5612 train_time:234307ms step_avg:43.80ms -step:5400/20000 train_loss:2.2834 train_time:236479ms step_avg:43.79ms -step:5400/20000 val_loss:2.1502 val_bpb:1.2735 train_time:236505ms step_avg:43.80ms -step:5450/20000 train_loss:2.1625 train_time:238651ms step_avg:43.79ms -step:5500/20000 train_loss:2.1710 train_time:240821ms step_avg:43.79ms -step:5550/20000 train_loss:2.1861 train_time:243062ms step_avg:43.79ms -step:5600/20000 train_loss:2.1630 train_time:245232ms step_avg:43.79ms -step:5600/20000 val_loss:2.1468 val_bpb:1.2715 train_time:245258ms step_avg:43.80ms -step:5650/20000 train_loss:2.1339 train_time:247403ms step_avg:43.79ms -step:5700/20000 train_loss:2.2647 train_time:249575ms step_avg:43.79ms -step:5750/20000 train_loss:2.0960 train_time:251821ms step_avg:43.80ms -step:5800/20000 train_loss:2.2417 train_time:253991ms step_avg:43.79ms -step:5800/20000 val_loss:2.1449 val_bpb:1.2703 train_time:254018ms step_avg:43.80ms -step:5850/20000 train_loss:2.3108 train_time:256163ms step_avg:43.79ms -step:5900/20000 train_loss:2.1466 train_time:258333ms step_avg:43.79ms -step:5950/20000 train_loss:2.0357 train_time:260573ms step_avg:43.79ms -step:6000/20000 train_loss:2.2133 train_time:262743ms step_avg:43.79ms -step:6000/20000 val_loss:2.1415 val_bpb:1.2683 train_time:262769ms step_avg:43.79ms -step:6050/20000 train_loss:2.0175 train_time:264915ms step_avg:43.79ms -step:6100/20000 train_loss:2.2965 train_time:267085ms step_avg:43.78ms -step:6150/20000 train_loss:1.9643 train_time:269337ms step_avg:43.79ms -step:6200/20000 train_loss:2.1052 train_time:271507ms step_avg:43.79ms -step:6200/20000 val_loss:2.1403 val_bpb:1.2676 train_time:271533ms step_avg:43.80ms -step:6250/20000 train_loss:2.1371 train_time:273679ms step_avg:43.79ms -step:6300/20000 train_loss:1.9363 train_time:275911ms step_avg:43.80ms -step:6350/20000 train_loss:2.1653 train_time:278081ms step_avg:43.79ms -step:6400/20000 train_loss:2.1143 train_time:280252ms step_avg:43.79ms -step:6400/20000 val_loss:2.1402 val_bpb:1.2675 train_time:280279ms step_avg:43.79ms -step:6450/20000 train_loss:2.1194 train_time:282423ms step_avg:43.79ms -step:6500/20000 train_loss:2.1142 train_time:284668ms step_avg:43.80ms -step:6550/20000 train_loss:2.0981 train_time:286840ms step_avg:43.79ms -step:6600/20000 train_loss:2.0068 train_time:289010ms step_avg:43.79ms -step:6600/20000 val_loss:2.1372 val_bpb:1.2657 train_time:289036ms step_avg:43.79ms -step:6650/20000 train_loss:2.2201 train_time:291182ms step_avg:43.79ms -step:6700/20000 train_loss:2.1445 train_time:293420ms step_avg:43.79ms -step:6750/20000 train_loss:2.1571 train_time:295591ms step_avg:43.79ms -step:6800/20000 train_loss:1.9528 train_time:297763ms step_avg:43.79ms -step:6800/20000 val_loss:2.1372 val_bpb:1.2658 train_time:297789ms step_avg:43.79ms -step:6850/20000 train_loss:2.0690 train_time:299936ms step_avg:43.79ms -step:6900/20000 train_loss:2.1382 train_time:302195ms step_avg:43.80ms -step:6950/20000 train_loss:2.0363 train_time:304368ms step_avg:43.79ms -step:7000/20000 train_loss:2.1949 train_time:306538ms step_avg:43.79ms -step:7000/20000 val_loss:2.1330 val_bpb:1.2633 train_time:306564ms step_avg:43.79ms -step:7050/20000 train_loss:2.0663 train_time:308710ms step_avg:43.79ms -step:7100/20000 train_loss:2.2323 train_time:310948ms step_avg:43.80ms -step:7150/20000 train_loss:2.1158 train_time:313119ms step_avg:43.79ms -step:7200/20000 train_loss:2.0351 train_time:315289ms step_avg:43.79ms -step:7200/20000 val_loss:2.1327 val_bpb:1.2631 train_time:315315ms step_avg:43.79ms -step:7250/20000 train_loss:2.0840 train_time:317544ms step_avg:43.80ms -step:7300/20000 train_loss:2.1822 train_time:319715ms step_avg:43.80ms -step:7350/20000 train_loss:2.2138 train_time:321888ms step_avg:43.79ms -step:7400/20000 train_loss:2.1412 train_time:324057ms step_avg:43.79ms -step:7400/20000 val_loss:2.1300 val_bpb:1.2615 train_time:324083ms step_avg:43.80ms -step:7450/20000 train_loss:2.1683 train_time:326297ms step_avg:43.80ms -step:7500/20000 train_loss:2.1322 train_time:328467ms step_avg:43.80ms -step:7550/20000 train_loss:2.1391 train_time:330637ms step_avg:43.79ms -step:7600/20000 train_loss:2.1593 train_time:332809ms step_avg:43.79ms -step:7600/20000 val_loss:2.1284 val_bpb:1.2606 train_time:332835ms step_avg:43.79ms -step:7650/20000 train_loss:2.1257 train_time:335067ms step_avg:43.80ms -step:7700/20000 train_loss:2.1509 train_time:337238ms step_avg:43.80ms -step:7750/20000 train_loss:2.2348 train_time:339409ms step_avg:43.79ms -step:7800/20000 train_loss:2.0846 train_time:341580ms step_avg:43.79ms -step:7800/20000 val_loss:2.1283 val_bpb:1.2605 train_time:341606ms step_avg:43.80ms -step:7850/20000 train_loss:2.1334 train_time:343814ms step_avg:43.80ms -step:7900/20000 train_loss:2.0901 train_time:345984ms step_avg:43.80ms -step:7950/20000 train_loss:2.1335 train_time:348154ms step_avg:43.79ms -step:8000/20000 train_loss:2.1574 train_time:350323ms step_avg:43.79ms -step:8000/20000 val_loss:2.1254 val_bpb:1.2588 train_time:350349ms step_avg:43.79ms -step:8050/20000 train_loss:2.1555 train_time:352566ms step_avg:43.80ms -step:8100/20000 train_loss:2.1871 train_time:354738ms step_avg:43.79ms -step:8150/20000 train_loss:2.0576 train_time:356908ms step_avg:43.79ms -step:8200/20000 train_loss:2.0271 train_time:359078ms step_avg:43.79ms -step:8200/20000 val_loss:2.1271 val_bpb:1.2598 train_time:359105ms step_avg:43.79ms -step:8250/20000 train_loss:2.1031 train_time:361316ms step_avg:43.80ms -step:8300/20000 train_loss:2.0508 train_time:363485ms step_avg:43.79ms -step:8350/20000 train_loss:2.1812 train_time:365655ms step_avg:43.79ms -step:8400/20000 train_loss:2.2092 train_time:367890ms step_avg:43.80ms -step:8400/20000 val_loss:2.1231 val_bpb:1.2574 train_time:367916ms step_avg:43.80ms -step:8450/20000 train_loss:2.1858 train_time:370061ms step_avg:43.79ms -step:8500/20000 train_loss:2.1313 train_time:372231ms step_avg:43.79ms -step:8550/20000 train_loss:2.1930 train_time:374401ms step_avg:43.79ms -step:8600/20000 train_loss:2.1264 train_time:376636ms step_avg:43.79ms -step:8600/20000 val_loss:2.1198 val_bpb:1.2555 train_time:376663ms step_avg:43.80ms -step:8650/20000 train_loss:2.0138 train_time:378807ms step_avg:43.79ms -step:8700/20000 train_loss:2.0725 train_time:380977ms step_avg:43.79ms -step:8750/20000 train_loss:2.1241 train_time:383147ms step_avg:43.79ms -step:8800/20000 train_loss:2.0577 train_time:385386ms step_avg:43.79ms -step:8800/20000 val_loss:2.1179 val_bpb:1.2543 train_time:385413ms step_avg:43.80ms -step:8850/20000 train_loss:2.0572 train_time:387558ms step_avg:43.79ms -step:8900/20000 train_loss:2.1148 train_time:389729ms step_avg:43.79ms -step:8950/20000 train_loss:2.1656 train_time:391900ms step_avg:43.79ms -step:9000/20000 train_loss:2.3235 train_time:394135ms step_avg:43.79ms -step:9000/20000 val_loss:2.1174 val_bpb:1.2541 train_time:394161ms step_avg:43.80ms -step:9050/20000 train_loss:2.1874 train_time:396307ms step_avg:43.79ms -step:9100/20000 train_loss:2.0096 train_time:398478ms step_avg:43.79ms -step:9150/20000 train_loss:2.2272 train_time:400649ms step_avg:43.79ms -step:9200/20000 train_loss:2.2863 train_time:402906ms step_avg:43.79ms -step:9200/20000 val_loss:2.1162 val_bpb:1.2533 train_time:402932ms step_avg:43.80ms -step:9250/20000 train_loss:2.1488 train_time:405078ms step_avg:43.79ms -step:9300/20000 train_loss:2.3716 train_time:407251ms step_avg:43.79ms -step:9350/20000 train_loss:2.1750 train_time:409481ms step_avg:43.79ms -step:9400/20000 train_loss:1.9241 train_time:411651ms step_avg:43.79ms -step:9400/20000 val_loss:2.1168 val_bpb:1.2537 train_time:411677ms step_avg:43.80ms -step:9450/20000 train_loss:2.0555 train_time:413821ms step_avg:43.79ms -step:9500/20000 train_loss:2.1634 train_time:415993ms step_avg:43.79ms -step:9550/20000 train_loss:2.2153 train_time:418244ms step_avg:43.80ms -step:9600/20000 train_loss:2.0163 train_time:420415ms step_avg:43.79ms -step:9600/20000 val_loss:2.1162 val_bpb:1.2533 train_time:420441ms step_avg:43.80ms -step:9650/20000 train_loss:2.0629 train_time:422589ms step_avg:43.79ms -step:9700/20000 train_loss:2.1591 train_time:424759ms step_avg:43.79ms -step:9750/20000 train_loss:2.1551 train_time:426994ms step_avg:43.79ms -step:9800/20000 train_loss:2.0674 train_time:429164ms step_avg:43.79ms -step:9800/20000 val_loss:2.1121 val_bpb:1.2509 train_time:429190ms step_avg:43.79ms -step:9850/20000 train_loss:2.1399 train_time:431336ms step_avg:43.79ms -step:9900/20000 train_loss:2.0134 train_time:433508ms step_avg:43.79ms -step:9950/20000 train_loss:2.1507 train_time:435750ms step_avg:43.79ms -step:10000/20000 train_loss:2.0270 train_time:437919ms step_avg:43.79ms -step:10000/20000 val_loss:2.1148 val_bpb:1.2525 train_time:437945ms step_avg:43.79ms -step:10050/20000 train_loss:2.0438 train_time:440091ms step_avg:43.79ms -step:10100/20000 train_loss:2.1068 train_time:442261ms step_avg:43.79ms -step:10150/20000 train_loss:2.1505 train_time:444498ms step_avg:43.79ms -step:10200/20000 train_loss:2.1344 train_time:446670ms step_avg:43.79ms -step:10200/20000 val_loss:2.1120 val_bpb:1.2509 train_time:446696ms step_avg:43.79ms -step:10250/20000 train_loss:2.0986 train_time:448842ms step_avg:43.79ms -step:10300/20000 train_loss:2.0060 train_time:451093ms step_avg:43.80ms -step:10350/20000 train_loss:1.9689 train_time:453261ms step_avg:43.79ms -step:10400/20000 train_loss:2.1015 train_time:455431ms step_avg:43.79ms -step:10400/20000 val_loss:2.1101 val_bpb:1.2497 train_time:455458ms step_avg:43.79ms -step:10450/20000 train_loss:1.9867 train_time:457604ms step_avg:43.79ms -step:10500/20000 train_loss:2.0564 train_time:459840ms step_avg:43.79ms -step:10550/20000 train_loss:2.1288 train_time:462011ms step_avg:43.79ms -step:10600/20000 train_loss:2.0751 train_time:464184ms step_avg:43.79ms -step:10600/20000 val_loss:2.1095 val_bpb:1.2494 train_time:464210ms step_avg:43.79ms -step:10650/20000 train_loss:2.1335 train_time:466355ms step_avg:43.79ms -step:10700/20000 train_loss:2.2599 train_time:468597ms step_avg:43.79ms -step:10750/20000 train_loss:2.0068 train_time:470768ms step_avg:43.79ms -step:10800/20000 train_loss:2.1260 train_time:472938ms step_avg:43.79ms -step:10800/20000 val_loss:2.1087 val_bpb:1.2489 train_time:472964ms step_avg:43.79ms -step:10850/20000 train_loss:2.2389 train_time:475109ms step_avg:43.79ms -step:10900/20000 train_loss:2.1941 train_time:477345ms step_avg:43.79ms -step:10950/20000 train_loss:2.0311 train_time:479516ms step_avg:43.79ms -step:11000/20000 train_loss:2.1000 train_time:481687ms step_avg:43.79ms -step:11000/20000 val_loss:2.1090 val_bpb:1.2491 train_time:481713ms step_avg:43.79ms -step:11050/20000 train_loss:2.1310 train_time:483857ms step_avg:43.79ms -step:11100/20000 train_loss:2.1062 train_time:486104ms step_avg:43.79ms -step:11150/20000 train_loss:2.0740 train_time:488274ms step_avg:43.79ms -step:11200/20000 train_loss:2.1338 train_time:490445ms step_avg:43.79ms -step:11200/20000 val_loss:2.1066 val_bpb:1.2476 train_time:490471ms step_avg:43.79ms -step:11250/20000 train_loss:2.1177 train_time:492616ms step_avg:43.79ms -step:11300/20000 train_loss:2.0860 train_time:494849ms step_avg:43.79ms -step:11350/20000 train_loss:2.0760 train_time:497021ms step_avg:43.79ms -step:11400/20000 train_loss:2.2242 train_time:499192ms step_avg:43.79ms -step:11400/20000 val_loss:2.1057 val_bpb:1.2471 train_time:499218ms step_avg:43.79ms -step:11450/20000 train_loss:2.1219 train_time:501431ms step_avg:43.79ms -step:11500/20000 train_loss:2.1016 train_time:503600ms step_avg:43.79ms -step:11550/20000 train_loss:2.1222 train_time:505770ms step_avg:43.79ms -step:11600/20000 train_loss:2.1468 train_time:507939ms step_avg:43.79ms -step:11600/20000 val_loss:2.1056 val_bpb:1.2471 train_time:507966ms step_avg:43.79ms -step:11650/20000 train_loss:2.1549 train_time:510177ms step_avg:43.79ms -step:11700/20000 train_loss:2.0195 train_time:512348ms step_avg:43.79ms -step:11750/20000 train_loss:2.0654 train_time:514518ms step_avg:43.79ms -step:11800/20000 train_loss:2.0257 train_time:516690ms step_avg:43.79ms -step:11800/20000 val_loss:2.1059 val_bpb:1.2472 train_time:516716ms step_avg:43.79ms -step:11850/20000 train_loss:2.1168 train_time:518940ms step_avg:43.79ms -step:11900/20000 train_loss:2.2975 train_time:521110ms step_avg:43.79ms -step:11950/20000 train_loss:2.1150 train_time:523282ms step_avg:43.79ms -step:12000/20000 train_loss:2.0739 train_time:525452ms step_avg:43.79ms -step:12000/20000 val_loss:2.1045 val_bpb:1.2464 train_time:525478ms step_avg:43.79ms -step:12050/20000 train_loss:2.1966 train_time:527686ms step_avg:43.79ms -step:12100/20000 train_loss:2.0933 train_time:529857ms step_avg:43.79ms -step:12150/20000 train_loss:2.1247 train_time:532028ms step_avg:43.79ms -step:12200/20000 train_loss:2.1644 train_time:534200ms step_avg:43.79ms -step:12200/20000 val_loss:2.1027 val_bpb:1.2453 train_time:534226ms step_avg:43.79ms -step:12250/20000 train_loss:2.0929 train_time:536449ms step_avg:43.79ms -step:12300/20000 train_loss:1.9745 train_time:538621ms step_avg:43.79ms -step:12350/20000 train_loss:2.0896 train_time:540792ms step_avg:43.79ms -step:12400/20000 train_loss:2.1576 train_time:543026ms step_avg:43.79ms -step:12400/20000 val_loss:2.1025 val_bpb:1.2452 train_time:543052ms step_avg:43.79ms -step:12450/20000 train_loss:2.1487 train_time:545198ms step_avg:43.79ms -step:12500/20000 train_loss:2.1525 train_time:547368ms step_avg:43.79ms -step:12550/20000 train_loss:2.1040 train_time:549538ms step_avg:43.79ms -step:12600/20000 train_loss:2.3639 train_time:551777ms step_avg:43.79ms -step:12600/20000 val_loss:2.0988 val_bpb:1.2430 train_time:551803ms step_avg:43.79ms -step:12650/20000 train_loss:1.9861 train_time:553950ms step_avg:43.79ms -step:12700/20000 train_loss:2.0478 train_time:556120ms step_avg:43.79ms -step:12750/20000 train_loss:2.0901 train_time:558294ms step_avg:43.79ms -step:12800/20000 train_loss:2.1683 train_time:560532ms step_avg:43.79ms -step:12800/20000 val_loss:2.0935 val_bpb:1.2399 train_time:560558ms step_avg:43.79ms -step:12850/20000 train_loss:1.9931 train_time:562704ms step_avg:43.79ms -step:12900/20000 train_loss:2.4707 train_time:564875ms step_avg:43.79ms -step:12950/20000 train_loss:2.0368 train_time:567047ms step_avg:43.79ms -step:13000/20000 train_loss:2.1259 train_time:569296ms step_avg:43.79ms -step:13000/20000 val_loss:2.0847 val_bpb:1.2347 train_time:569323ms step_avg:43.79ms -step:13050/20000 train_loss:2.0280 train_time:571468ms step_avg:43.79ms -step:13100/20000 train_loss:1.9683 train_time:573638ms step_avg:43.79ms -step:13150/20000 train_loss:2.0988 train_time:575809ms step_avg:43.79ms -step:13200/20000 train_loss:2.1256 train_time:578042ms step_avg:43.79ms -step:13200/20000 val_loss:2.0771 val_bpb:1.2302 train_time:578068ms step_avg:43.79ms -step:13250/20000 train_loss:2.1152 train_time:580213ms step_avg:43.79ms -step:13300/20000 train_loss:2.1388 train_time:582383ms step_avg:43.79ms -step:13350/20000 train_loss:2.1556 train_time:584554ms step_avg:43.79ms -step:13400/20000 train_loss:2.1470 train_time:586802ms step_avg:43.79ms -step:13400/20000 val_loss:2.0698 val_bpb:1.2259 train_time:586828ms step_avg:43.79ms -step:13450/20000 train_loss:2.1367 train_time:588974ms step_avg:43.79ms -step:13500/20000 train_loss:1.9808 train_time:591144ms step_avg:43.79ms -step:13550/20000 train_loss:2.0810 train_time:593382ms step_avg:43.79ms -step:13600/20000 train_loss:2.0428 train_time:595553ms step_avg:43.79ms -step:13600/20000 val_loss:2.0624 val_bpb:1.2215 train_time:595579ms step_avg:43.79ms -step:13650/20000 train_loss:2.0705 train_time:597724ms step_avg:43.79ms -step:13700/20000 train_loss:2.0979 train_time:599895ms step_avg:43.79ms -step:13703/20000 val_loss:2.0601 val_bpb:1.2201 train_time:600051ms step_avg:43.79ms -stopping_early: wallclock_cap train_time:600051ms step:13703/20000 -peak memory allocated: 10185 MiB reserved: 10572 MiB -Serialized model: 67224983 bytes -Code size: 60906 bytes -Total submission size: 67285889 bytes -Serialized model int8+zlib: 15823937 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15884843 bytes -final_int8_zlib_roundtrip val_loss:2.0724 val_bpb:1.2274 eval_time:1374ms -final_int8_zlib_roundtrip_exact val_loss:2.07240908 val_bpb:1.22739739 -final_int8_ttt_lora val_loss:2.0142 val_bpb:1.1929 eval_time:60184ms diff --git a/records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md b/records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md deleted file mode 100644 index 1eb352a8a3..0000000000 --- a/records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md +++ /dev/null @@ -1,46 +0,0 @@ -This record captures the `Simple Baseline`. - -Trainer changes in this snapshot: -- current repository `train_gpt.py` snapshot copied into the record folder -- published `fineweb10B_sp1024` dataset and tokenizer loaded from the new Hugging Face export -- 10-minute wallclock cap on `8xH100` -- periodic validation every `200` steps on the full `fineweb_val_*` split - -Configuration: -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied output/input embeddings: `TIE_EMBEDDINGS=1` -- Tied embedding LR: `TIED_EMBED_LR=0.05` -- Batching: `TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=1024` - -Command (track-relevant params): -```bash -NCCL_IB_DISABLE=1 \ -RUN_ID=hf_verify_sp1024_8gpu \ -DATA_PATH=/root/code/parameter-golf/data/datasets/fineweb10B_sp1024 \ -TOKENIZER_PATH=/root/code/parameter-golf/data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -MAX_WALLCLOCK_SECONDS=600 \ -TRAIN_LOG_EVERY=50 \ -VAL_LOSS_EVERY=200 \ -torchrun --standalone --nproc_per_node=8 /root/code/parameter-golf/train_gpt.py -``` - -Key metrics (from `train.log`): -- Timed training stopped at `13780/20000` steps due to the wallclock cap. -- Pre-quant eval at stop: `val_loss:2.0606`, `val_bpb:1.2172` -- Post-quant roundtrip eval: `val_loss:2.0727`, `val_bpb:1.2244` -- Exact printed metric: `final_int8_zlib_roundtrip_exact val_bpb:1.22436570` -- Train time: `600038ms` (`step_avg:43.54ms`) -- Peak memory: `10184 MiB allocated`, `10200 MiB reserved` -- Serialized model int8+zlib: `15815847 bytes` -- Code size: `47642 bytes` -- Total submission size int8+zlib: `15863489 bytes` - -Training volume: -- Global batch: `524288` tokens/step -- Total train tokens seen: `7224688640` - -Included files: -- `train_gpt.py` (code snapshot used for the run) -- `train.log` (exact remote training log) -- `submission.json` (leaderboard metadata) diff --git a/records/track_10min_16mb/2026-03-17_NaiveBaseline/submission.json b/records/track_10min_16mb/2026-03-17_NaiveBaseline/submission.json deleted file mode 100644 index faffa83ea8..0000000000 --- a/records/track_10min_16mb/2026-03-17_NaiveBaseline/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "Baseline", - "github_id": "openai", - "name": "Naive Baseline", - "blurb": "SP-1024 9x512 KV4 run on pgut1 using the published Hugging Face fineweb10B_sp1024 export and the current train_gpt.py; score is the default final int8+zlib roundtrip metric under the 16,000,000-byte cap.", - "date": "2026-03-18T14:56:29Z", - "val_loss": 2.07269931, - "val_bpb": 1.2243657, - "bytes_total": 15863489, - "bytes_code": 47642 -} diff --git a/records/track_10min_16mb/2026-03-17_NaiveBaseline/train.log b/records/track_10min_16mb/2026-03-17_NaiveBaseline/train.log deleted file mode 100644 index 69b17b6c7f..0000000000 --- a/records/track_10min_16mb/2026-03-17_NaiveBaseline/train.log +++ /dev/null @@ -1,448 +0,0 @@ -W0318 14:37:59.159000 871689 site-packages/torch/distributed/run.py:852] -W0318 14:37:59.159000 871689 site-packages/torch/distributed/run.py:852] ***************************************** -W0318 14:37:59.159000 871689 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0318 14:37:59.159000 871689 site-packages/torch/distributed/run.py:852] ***************************************** -[W318 14:38:11.514156940 Utils.hpp:137] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[W318 14:38:11.543417305 Utils.hpp:137] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[W318 14:38:11.552597211 Utils.hpp:137] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -NCCL version 2.27.5+cuda12.9 -[W318 14:38:11.832390267 Utils.hpp:137] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[W318 14:38:11.842257581 Utils.hpp:137] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[W318 14:38:11.842253680 Utils.hpp:137] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[W318 14:38:11.899166383 Utils.hpp:137] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[W318 14:38:11.901800020 Utils.hpp:137] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) - -[2026-03-18 14:38:12] pgut1-0:871784:871848 [5] ibvwrap.c:94 NCCL WARN Call to ibv_open_device failed - -[2026-03-18 14:38:12] pgut1-0:871784:871848 [5] p2p_plugin.c:565 NCCL WARN NET/IB : Unable to open device mlx5_an0 - -[2026-03-18 14:38:12] pgut1-0:871786:871849 [7] ibvwrap.c:94 NCCL WARN Call to ibv_open_device failed - -[2026-03-18 14:38:12] pgut1-0:871786:871849 [7] p2p_plugin.c:565 NCCL WARN NET/IB : Unable to open device mlx5_an0 - -[2026-03-18 14:38:12] pgut1-0:871779:871850 [0] ibvwrap.c:94 NCCL WARN Call to ibv_open_device failed - -[2026-03-18 14:38:12] pgut1-0:871779:871850 [0] p2p_plugin.c:565 NCCL WARN NET/IB : Unable to open device mlx5_an0 - -[2026-03-18 14:38:12] pgut1-0:871780:871857 [1] ibvwrap.c:94 NCCL WARN Call to ibv_open_device failed - -[2026-03-18 14:38:12] pgut1-0:871780:871857 [1] p2p_plugin.c:565 NCCL WARN NET/IB : Unable to open device mlx5_an0 - -[2026-03-18 14:38:12] pgut1-0:871781:871858 [2] ibvwrap.c:94 NCCL WARN Call to ibv_open_device failed - -[2026-03-18 14:38:12] pgut1-0:871781:871858 [2] p2p_plugin.c:565 NCCL WARN NET/IB : Unable to open device mlx5_an0 - -[2026-03-18 14:38:12] pgut1-0:871783:871859 [4] ibvwrap.c:94 NCCL WARN Call to ibv_open_device failed - -[2026-03-18 14:38:12] pgut1-0:871783:871859 [4] p2p_plugin.c:565 NCCL WARN NET/IB : Unable to open device mlx5_an0 - -[2026-03-18 14:38:12] pgut1-0:871782:871864 [3] ibvwrap.c:94 NCCL WARN Call to ibv_open_device failed - -[2026-03-18 14:38:12] pgut1-0:871782:871864 [3] p2p_plugin.c:565 NCCL WARN NET/IB : Unable to open device mlx5_an0 - -[2026-03-18 14:38:12] pgut1-0:871785:871865 [6] ibvwrap.c:94 NCCL WARN Call to ibv_open_device failed - -[2026-03-18 14:38:12] pgut1-0:871785:871865 [6] p2p_plugin.c:565 NCCL WARN NET/IB : Unable to open device mlx5_an0 -logs/hf_verify_sp1024_8gpu.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/code/parameter-golf/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:25 -val_loader:shards pattern=/root/code/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:63779840 -[rank0]:[W318 14:38:18.833454927 Utils.hpp:112] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -[rank3]:[W318 14:38:18.835915381 Utils.hpp:112] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[rank7]:[W318 14:38:18.835951425 Utils.hpp:112] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[rank6]:[W318 14:38:18.835967008 Utils.hpp:112] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[rank2]:[W318 14:38:18.836023454 Utils.hpp:112] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[rank5]:[W318 14:38:18.836119632 Utils.hpp:112] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[rank4]:[W318 14:38:18.836127772 Utils.hpp:112] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -[rank1]:[W318 14:38:18.836354967 Utils.hpp:112] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator()) -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9370 val_bpb:4.0978 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9408 train_time:24ms step_avg:23.99ms -step:2/20000 train_loss:16.8763 train_time:67ms step_avg:33.39ms -step:3/20000 train_loss:9.0044 train_time:110ms step_avg:36.62ms -step:4/20000 train_loss:6.5686 train_time:152ms step_avg:37.99ms -step:5/20000 train_loss:6.6665 train_time:195ms step_avg:38.97ms -step:6/20000 train_loss:6.5027 train_time:239ms step_avg:39.81ms -step:7/20000 train_loss:6.2808 train_time:280ms step_avg:40.05ms -step:8/20000 train_loss:5.9951 train_time:324ms step_avg:40.52ms -step:9/20000 train_loss:6.0187 train_time:367ms step_avg:40.77ms -step:10/20000 train_loss:5.9718 train_time:409ms step_avg:40.93ms -step:50/20000 train_loss:3.9508 train_time:2126ms step_avg:42.52ms -step:100/20000 train_loss:3.3373 train_time:4267ms step_avg:42.67ms -step:150/20000 train_loss:2.9651 train_time:6414ms step_avg:42.76ms -step:200/20000 train_loss:2.8041 train_time:8677ms step_avg:43.38ms -step:200/20000 val_loss:2.8397 val_bpb:1.6774 train_time:8699ms step_avg:43.49ms -step:250/20000 train_loss:2.7379 train_time:10816ms step_avg:43.27ms -step:300/20000 train_loss:2.6613 train_time:12958ms step_avg:43.19ms -step:350/20000 train_loss:2.6434 train_time:15097ms step_avg:43.13ms -step:400/20000 train_loss:2.7684 train_time:17357ms step_avg:43.39ms -step:400/20000 val_loss:2.5687 val_bpb:1.5174 train_time:17382ms step_avg:43.45ms -step:450/20000 train_loss:2.6035 train_time:19502ms step_avg:43.34ms -step:500/20000 train_loss:2.5265 train_time:21643ms step_avg:43.29ms -step:550/20000 train_loss:2.4803 train_time:23782ms step_avg:43.24ms -step:600/20000 train_loss:2.4731 train_time:26034ms step_avg:43.39ms -step:600/20000 val_loss:2.4456 val_bpb:1.4447 train_time:26059ms step_avg:43.43ms -step:650/20000 train_loss:2.3204 train_time:28175ms step_avg:43.35ms -step:700/20000 train_loss:2.5926 train_time:30315ms step_avg:43.31ms -step:750/20000 train_loss:2.4301 train_time:32457ms step_avg:43.28ms -step:800/20000 train_loss:2.4775 train_time:34707ms step_avg:43.38ms -step:800/20000 val_loss:2.3868 val_bpb:1.4099 train_time:34732ms step_avg:43.42ms -step:850/20000 train_loss:2.3941 train_time:36851ms step_avg:43.35ms -step:900/20000 train_loss:2.3716 train_time:38990ms step_avg:43.32ms -step:950/20000 train_loss:2.3216 train_time:41131ms step_avg:43.30ms -step:1000/20000 train_loss:2.3030 train_time:43390ms step_avg:43.39ms -step:1000/20000 val_loss:2.3370 val_bpb:1.3805 train_time:43415ms step_avg:43.42ms -step:1050/20000 train_loss:2.3893 train_time:45532ms step_avg:43.36ms -step:1100/20000 train_loss:2.4145 train_time:47675ms step_avg:43.34ms -step:1150/20000 train_loss:2.2261 train_time:49933ms step_avg:43.42ms -step:1200/20000 train_loss:2.2607 train_time:52072ms step_avg:43.39ms -step:1200/20000 val_loss:2.3026 val_bpb:1.3602 train_time:52097ms step_avg:43.41ms -step:1250/20000 train_loss:2.3312 train_time:54219ms step_avg:43.38ms -step:1300/20000 train_loss:2.3575 train_time:56363ms step_avg:43.36ms -step:1350/20000 train_loss:2.2774 train_time:58628ms step_avg:43.43ms -step:1400/20000 train_loss:2.2436 train_time:60772ms step_avg:43.41ms -step:1400/20000 val_loss:2.2812 val_bpb:1.3475 train_time:60797ms step_avg:43.43ms -step:1450/20000 train_loss:2.3006 train_time:62917ms step_avg:43.39ms -step:1500/20000 train_loss:2.2831 train_time:65060ms step_avg:43.37ms -step:1550/20000 train_loss:2.2957 train_time:67324ms step_avg:43.43ms -step:1600/20000 train_loss:2.2187 train_time:69467ms step_avg:43.42ms -step:1600/20000 val_loss:2.2631 val_bpb:1.3368 train_time:69491ms step_avg:43.43ms -step:1650/20000 train_loss:2.2629 train_time:71614ms step_avg:43.40ms -step:1700/20000 train_loss:2.2619 train_time:73759ms step_avg:43.39ms -step:1750/20000 train_loss:2.1068 train_time:76028ms step_avg:43.44ms -step:1800/20000 train_loss:2.3312 train_time:78171ms step_avg:43.43ms -step:1800/20000 val_loss:2.2479 val_bpb:1.3279 train_time:78197ms step_avg:43.44ms -step:1850/20000 train_loss:2.2211 train_time:80317ms step_avg:43.41ms -step:1900/20000 train_loss:2.2477 train_time:82462ms step_avg:43.40ms -step:1950/20000 train_loss:2.2707 train_time:84723ms step_avg:43.45ms -step:2000/20000 train_loss:2.2346 train_time:86867ms step_avg:43.43ms -step:2000/20000 val_loss:2.2368 val_bpb:1.3213 train_time:86892ms step_avg:43.45ms -step:2050/20000 train_loss:2.0689 train_time:89013ms step_avg:43.42ms -step:2100/20000 train_loss:2.3382 train_time:91276ms step_avg:43.46ms -step:2150/20000 train_loss:2.1161 train_time:93418ms step_avg:43.45ms -step:2200/20000 train_loss:2.2380 train_time:95565ms step_avg:43.44ms -step:2200/20000 val_loss:2.2251 val_bpb:1.3144 train_time:95590ms step_avg:43.45ms -step:2250/20000 train_loss:2.2362 train_time:97711ms step_avg:43.43ms -step:2300/20000 train_loss:2.2390 train_time:99973ms step_avg:43.47ms -step:2350/20000 train_loss:2.1494 train_time:102118ms step_avg:43.45ms -step:2400/20000 train_loss:2.1004 train_time:104264ms step_avg:43.44ms -step:2400/20000 val_loss:2.2158 val_bpb:1.3089 train_time:104288ms step_avg:43.45ms -step:2450/20000 train_loss:2.2078 train_time:106409ms step_avg:43.43ms -step:2500/20000 train_loss:2.2990 train_time:108679ms step_avg:43.47ms -step:2550/20000 train_loss:2.3510 train_time:110825ms step_avg:43.46ms -step:2600/20000 train_loss:2.1989 train_time:112969ms step_avg:43.45ms -step:2600/20000 val_loss:2.2097 val_bpb:1.3053 train_time:112994ms step_avg:43.46ms -step:2650/20000 train_loss:2.0953 train_time:115115ms step_avg:43.44ms -step:2700/20000 train_loss:2.2119 train_time:117382ms step_avg:43.47ms -step:2750/20000 train_loss:2.2833 train_time:119524ms step_avg:43.46ms -step:2800/20000 train_loss:2.2056 train_time:121673ms step_avg:43.45ms -step:2800/20000 val_loss:2.2011 val_bpb:1.3002 train_time:121697ms step_avg:43.46ms -step:2850/20000 train_loss:2.1613 train_time:123815ms step_avg:43.44ms -step:2900/20000 train_loss:2.2400 train_time:126078ms step_avg:43.48ms -step:2950/20000 train_loss:2.2531 train_time:128222ms step_avg:43.47ms -step:3000/20000 train_loss:2.1098 train_time:130368ms step_avg:43.46ms -step:3000/20000 val_loss:2.1953 val_bpb:1.2968 train_time:130392ms step_avg:43.46ms -step:3050/20000 train_loss:2.4246 train_time:132514ms step_avg:43.45ms -step:3100/20000 train_loss:2.1884 train_time:134780ms step_avg:43.48ms -step:3150/20000 train_loss:2.2749 train_time:136926ms step_avg:43.47ms -step:3200/20000 train_loss:2.1492 train_time:139071ms step_avg:43.46ms -step:3200/20000 val_loss:2.1881 val_bpb:1.2925 train_time:139096ms step_avg:43.47ms -step:3250/20000 train_loss:2.1286 train_time:141341ms step_avg:43.49ms -step:3300/20000 train_loss:2.1058 train_time:143485ms step_avg:43.48ms -step:3350/20000 train_loss:2.2214 train_time:145628ms step_avg:43.47ms -step:3400/20000 train_loss:2.2454 train_time:147773ms step_avg:43.46ms -step:3400/20000 val_loss:2.1854 val_bpb:1.2909 train_time:147798ms step_avg:43.47ms -step:3450/20000 train_loss:2.2601 train_time:150039ms step_avg:43.49ms -step:3500/20000 train_loss:2.1183 train_time:152184ms step_avg:43.48ms -step:3550/20000 train_loss:2.0846 train_time:154329ms step_avg:43.47ms -step:3600/20000 train_loss:2.2507 train_time:156472ms step_avg:43.46ms -step:3600/20000 val_loss:2.1784 val_bpb:1.2868 train_time:156496ms step_avg:43.47ms -step:3650/20000 train_loss:2.1383 train_time:158738ms step_avg:43.49ms -step:3700/20000 train_loss:2.2848 train_time:160882ms step_avg:43.48ms -step:3750/20000 train_loss:2.1982 train_time:163029ms step_avg:43.47ms -step:3800/20000 train_loss:2.1399 train_time:165176ms step_avg:43.47ms -step:3800/20000 val_loss:2.1767 val_bpb:1.2858 train_time:165200ms step_avg:43.47ms -step:3850/20000 train_loss:2.3361 train_time:167438ms step_avg:43.49ms -step:3900/20000 train_loss:2.2756 train_time:169582ms step_avg:43.48ms -step:3950/20000 train_loss:2.1261 train_time:171729ms step_avg:43.48ms -step:4000/20000 train_loss:2.1437 train_time:173878ms step_avg:43.47ms -step:4000/20000 val_loss:2.1718 val_bpb:1.2829 train_time:173903ms step_avg:43.48ms -step:4050/20000 train_loss:2.1718 train_time:176147ms step_avg:43.49ms -step:4100/20000 train_loss:2.1899 train_time:178291ms step_avg:43.49ms -step:4150/20000 train_loss:2.1285 train_time:180438ms step_avg:43.48ms -step:4200/20000 train_loss:2.0498 train_time:182707ms step_avg:43.50ms -step:4200/20000 val_loss:2.1666 val_bpb:1.2798 train_time:182731ms step_avg:43.51ms -step:4250/20000 train_loss:2.2487 train_time:184852ms step_avg:43.49ms -step:4300/20000 train_loss:2.1979 train_time:186996ms step_avg:43.49ms -step:4350/20000 train_loss:2.1314 train_time:189141ms step_avg:43.48ms -step:4400/20000 train_loss:2.1727 train_time:191402ms step_avg:43.50ms -step:4400/20000 val_loss:2.1625 val_bpb:1.2774 train_time:191427ms step_avg:43.51ms -step:4450/20000 train_loss:2.1882 train_time:193549ms step_avg:43.49ms -step:4500/20000 train_loss:2.0735 train_time:195696ms step_avg:43.49ms -step:4550/20000 train_loss:2.1347 train_time:197840ms step_avg:43.48ms -step:4600/20000 train_loss:2.1710 train_time:200091ms step_avg:43.50ms -step:4600/20000 val_loss:2.1597 val_bpb:1.2757 train_time:200114ms step_avg:43.50ms -step:4650/20000 train_loss:2.2563 train_time:202236ms step_avg:43.49ms -step:4700/20000 train_loss:2.2077 train_time:204381ms step_avg:43.49ms -step:4750/20000 train_loss:2.1328 train_time:206643ms step_avg:43.50ms -step:4800/20000 train_loss:2.1473 train_time:208788ms step_avg:43.50ms -step:4800/20000 val_loss:2.1579 val_bpb:1.2747 train_time:208812ms step_avg:43.50ms -step:4850/20000 train_loss:2.2067 train_time:210933ms step_avg:43.49ms -step:4900/20000 train_loss:2.1119 train_time:213078ms step_avg:43.49ms -step:4950/20000 train_loss:2.0031 train_time:215339ms step_avg:43.50ms -step:5000/20000 train_loss:2.1104 train_time:217483ms step_avg:43.50ms -step:5000/20000 val_loss:2.1532 val_bpb:1.2719 train_time:217508ms step_avg:43.50ms -step:5050/20000 train_loss:2.0232 train_time:219627ms step_avg:43.49ms -step:5100/20000 train_loss:2.1995 train_time:221774ms step_avg:43.49ms -step:5150/20000 train_loss:2.0709 train_time:224038ms step_avg:43.50ms -step:5200/20000 train_loss:2.0972 train_time:226182ms step_avg:43.50ms -step:5200/20000 val_loss:2.1501 val_bpb:1.2701 train_time:226207ms step_avg:43.50ms -step:5250/20000 train_loss:2.1395 train_time:228330ms step_avg:43.49ms -step:5300/20000 train_loss:2.0947 train_time:230476ms step_avg:43.49ms -step:5350/20000 train_loss:2.0819 train_time:232740ms step_avg:43.50ms -step:5400/20000 train_loss:2.2099 train_time:234884ms step_avg:43.50ms -step:5400/20000 val_loss:2.1475 val_bpb:1.2685 train_time:234909ms step_avg:43.50ms -step:5450/20000 train_loss:2.1314 train_time:237031ms step_avg:43.49ms -step:5500/20000 train_loss:2.2057 train_time:239295ms step_avg:43.51ms -step:5550/20000 train_loss:2.0856 train_time:241437ms step_avg:43.50ms -step:5600/20000 train_loss:2.1448 train_time:243583ms step_avg:43.50ms -step:5600/20000 val_loss:2.1455 val_bpb:1.2674 train_time:243608ms step_avg:43.50ms -step:5650/20000 train_loss:2.0312 train_time:245730ms step_avg:43.49ms -step:5700/20000 train_loss:2.1392 train_time:247996ms step_avg:43.51ms -step:5750/20000 train_loss:2.0206 train_time:250140ms step_avg:43.50ms -step:5800/20000 train_loss:2.2107 train_time:252283ms step_avg:43.50ms -step:5800/20000 val_loss:2.1439 val_bpb:1.2664 train_time:252308ms step_avg:43.50ms -step:5850/20000 train_loss:2.0973 train_time:254429ms step_avg:43.49ms -step:5900/20000 train_loss:2.1270 train_time:256697ms step_avg:43.51ms -step:5950/20000 train_loss:2.0899 train_time:258840ms step_avg:43.50ms -step:6000/20000 train_loss:2.2182 train_time:260985ms step_avg:43.50ms -step:6000/20000 val_loss:2.1445 val_bpb:1.2668 train_time:261009ms step_avg:43.50ms -step:6050/20000 train_loss:2.1230 train_time:263130ms step_avg:43.49ms -step:6100/20000 train_loss:2.1640 train_time:265401ms step_avg:43.51ms -step:6150/20000 train_loss:2.1960 train_time:267547ms step_avg:43.50ms -step:6200/20000 train_loss:2.1217 train_time:269692ms step_avg:43.50ms -step:6200/20000 val_loss:2.1416 val_bpb:1.2651 train_time:269717ms step_avg:43.50ms -step:6250/20000 train_loss:2.1106 train_time:271837ms step_avg:43.49ms -step:6300/20000 train_loss:2.1989 train_time:274105ms step_avg:43.51ms -step:6350/20000 train_loss:2.1738 train_time:276249ms step_avg:43.50ms -step:6400/20000 train_loss:2.1333 train_time:278396ms step_avg:43.50ms -step:6400/20000 val_loss:2.1377 val_bpb:1.2628 train_time:278421ms step_avg:43.50ms -step:6450/20000 train_loss:1.9696 train_time:280544ms step_avg:43.50ms -step:6500/20000 train_loss:2.1279 train_time:282815ms step_avg:43.51ms -step:6550/20000 train_loss:2.2768 train_time:284958ms step_avg:43.51ms -step:6600/20000 train_loss:2.1060 train_time:287102ms step_avg:43.50ms -step:6600/20000 val_loss:2.1354 val_bpb:1.2614 train_time:287126ms step_avg:43.50ms -step:6650/20000 train_loss:2.1036 train_time:289368ms step_avg:43.51ms -step:6700/20000 train_loss:2.1438 train_time:291511ms step_avg:43.51ms -step:6750/20000 train_loss:1.8938 train_time:293654ms step_avg:43.50ms -step:6800/20000 train_loss:2.1809 train_time:295799ms step_avg:43.50ms -step:6800/20000 val_loss:2.1342 val_bpb:1.2607 train_time:295824ms step_avg:43.50ms -step:6850/20000 train_loss:2.0978 train_time:298068ms step_avg:43.51ms -step:6900/20000 train_loss:2.1146 train_time:300210ms step_avg:43.51ms -step:6950/20000 train_loss:2.1328 train_time:302354ms step_avg:43.50ms -step:7000/20000 train_loss:2.1537 train_time:304499ms step_avg:43.50ms -step:7000/20000 val_loss:2.1326 val_bpb:1.2598 train_time:304523ms step_avg:43.50ms -step:7050/20000 train_loss:2.1382 train_time:306765ms step_avg:43.51ms -step:7100/20000 train_loss:2.1078 train_time:308911ms step_avg:43.51ms -step:7150/20000 train_loss:2.1952 train_time:311056ms step_avg:43.50ms -step:7200/20000 train_loss:2.1143 train_time:313204ms step_avg:43.50ms -step:7200/20000 val_loss:2.1299 val_bpb:1.2582 train_time:313228ms step_avg:43.50ms -step:7250/20000 train_loss:2.1009 train_time:315469ms step_avg:43.51ms -step:7300/20000 train_loss:2.1529 train_time:317612ms step_avg:43.51ms -step:7350/20000 train_loss:2.1532 train_time:319759ms step_avg:43.50ms -step:7400/20000 train_loss:2.1137 train_time:321901ms step_avg:43.50ms -step:7400/20000 val_loss:2.1282 val_bpb:1.2572 train_time:321927ms step_avg:43.50ms -step:7450/20000 train_loss:2.4067 train_time:324167ms step_avg:43.51ms -step:7500/20000 train_loss:2.0751 train_time:326311ms step_avg:43.51ms -step:7550/20000 train_loss:2.1258 train_time:328457ms step_avg:43.50ms -step:7600/20000 train_loss:2.1723 train_time:330730ms step_avg:43.52ms -step:7600/20000 val_loss:2.1289 val_bpb:1.2576 train_time:330754ms step_avg:43.52ms -step:7650/20000 train_loss:2.2193 train_time:332878ms step_avg:43.51ms -step:7700/20000 train_loss:2.1329 train_time:335023ms step_avg:43.51ms -step:7750/20000 train_loss:2.0562 train_time:337169ms step_avg:43.51ms -step:7800/20000 train_loss:2.1669 train_time:339436ms step_avg:43.52ms -step:7800/20000 val_loss:2.1252 val_bpb:1.2554 train_time:339460ms step_avg:43.52ms -step:7850/20000 train_loss:2.0994 train_time:341583ms step_avg:43.51ms -step:7900/20000 train_loss:2.1585 train_time:343729ms step_avg:43.51ms -step:7950/20000 train_loss:2.1319 train_time:345873ms step_avg:43.51ms -step:8000/20000 train_loss:2.2613 train_time:348141ms step_avg:43.52ms -step:8000/20000 val_loss:2.1232 val_bpb:1.2542 train_time:348165ms step_avg:43.52ms -step:8050/20000 train_loss:2.1775 train_time:350287ms step_avg:43.51ms -step:8100/20000 train_loss:1.9587 train_time:352431ms step_avg:43.51ms -step:8150/20000 train_loss:2.0401 train_time:354575ms step_avg:43.51ms -step:8200/20000 train_loss:2.1076 train_time:356845ms step_avg:43.52ms -step:8200/20000 val_loss:2.1228 val_bpb:1.2540 train_time:356869ms step_avg:43.52ms -step:8250/20000 train_loss:2.0951 train_time:358988ms step_avg:43.51ms -step:8300/20000 train_loss:2.2244 train_time:361133ms step_avg:43.51ms -step:8350/20000 train_loss:2.0681 train_time:363279ms step_avg:43.51ms -step:8400/20000 train_loss:2.1494 train_time:365552ms step_avg:43.52ms -step:8400/20000 val_loss:2.1201 val_bpb:1.2524 train_time:365577ms step_avg:43.52ms -step:8450/20000 train_loss:2.1278 train_time:367698ms step_avg:43.51ms -step:8500/20000 train_loss:2.0289 train_time:369845ms step_avg:43.51ms -step:8550/20000 train_loss:2.0465 train_time:372114ms step_avg:43.52ms -step:8600/20000 train_loss:2.0682 train_time:374259ms step_avg:43.52ms -step:8600/20000 val_loss:2.1206 val_bpb:1.2526 train_time:374282ms step_avg:43.52ms -step:8650/20000 train_loss:2.2717 train_time:376403ms step_avg:43.51ms -step:8700/20000 train_loss:2.1795 train_time:378549ms step_avg:43.51ms -step:8750/20000 train_loss:2.0492 train_time:380817ms step_avg:43.52ms -step:8800/20000 train_loss:2.1100 train_time:382964ms step_avg:43.52ms -step:8800/20000 val_loss:2.1192 val_bpb:1.2518 train_time:382989ms step_avg:43.52ms -step:8850/20000 train_loss:2.4323 train_time:385110ms step_avg:43.52ms -step:8900/20000 train_loss:2.1016 train_time:387258ms step_avg:43.51ms -step:8950/20000 train_loss:2.0290 train_time:389530ms step_avg:43.52ms -step:9000/20000 train_loss:2.1119 train_time:391675ms step_avg:43.52ms -step:9000/20000 val_loss:2.1204 val_bpb:1.2525 train_time:391698ms step_avg:43.52ms -step:9050/20000 train_loss:2.0826 train_time:393819ms step_avg:43.52ms -step:9100/20000 train_loss:2.0427 train_time:395963ms step_avg:43.51ms -step:9150/20000 train_loss:2.1201 train_time:398238ms step_avg:43.52ms -step:9200/20000 train_loss:2.1490 train_time:400385ms step_avg:43.52ms -step:9200/20000 val_loss:2.1170 val_bpb:1.2505 train_time:400409ms step_avg:43.52ms -step:9250/20000 train_loss:2.1221 train_time:402534ms step_avg:43.52ms -step:9300/20000 train_loss:2.4550 train_time:404680ms step_avg:43.51ms -step:9350/20000 train_loss:2.0384 train_time:406932ms step_avg:43.52ms -step:9400/20000 train_loss:2.0736 train_time:409077ms step_avg:43.52ms -step:9400/20000 val_loss:2.1139 val_bpb:1.2487 train_time:409102ms step_avg:43.52ms -step:9450/20000 train_loss:2.1096 train_time:411223ms step_avg:43.52ms -step:9500/20000 train_loss:2.1070 train_time:413493ms step_avg:43.53ms -step:9550/20000 train_loss:2.0249 train_time:415641ms step_avg:43.52ms -step:9600/20000 train_loss:2.1141 train_time:417785ms step_avg:43.52ms -step:9600/20000 val_loss:2.1138 val_bpb:1.2486 train_time:417809ms step_avg:43.52ms -step:9650/20000 train_loss:2.0183 train_time:419932ms step_avg:43.52ms -step:9700/20000 train_loss:2.1482 train_time:422212ms step_avg:43.53ms -step:9750/20000 train_loss:2.1811 train_time:424359ms step_avg:43.52ms -step:9800/20000 train_loss:2.1011 train_time:426503ms step_avg:43.52ms -step:9800/20000 val_loss:2.1143 val_bpb:1.2489 train_time:426528ms step_avg:43.52ms -step:9850/20000 train_loss:2.1134 train_time:428771ms step_avg:43.53ms -step:9900/20000 train_loss:2.0497 train_time:430915ms step_avg:43.53ms -step:9950/20000 train_loss:2.1989 train_time:433061ms step_avg:43.52ms -step:10000/20000 train_loss:2.1982 train_time:435207ms step_avg:43.52ms -step:10000/20000 val_loss:2.1122 val_bpb:1.2477 train_time:435232ms step_avg:43.52ms -step:10050/20000 train_loss:2.0940 train_time:437485ms step_avg:43.53ms -step:10100/20000 train_loss:2.1277 train_time:439630ms step_avg:43.53ms -step:10150/20000 train_loss:2.0896 train_time:441773ms step_avg:43.52ms -step:10200/20000 train_loss:2.0642 train_time:443918ms step_avg:43.52ms -step:10200/20000 val_loss:2.1112 val_bpb:1.2471 train_time:443941ms step_avg:43.52ms -step:10250/20000 train_loss:2.0627 train_time:446192ms step_avg:43.53ms -step:10300/20000 train_loss:2.2191 train_time:448339ms step_avg:43.53ms -step:10350/20000 train_loss:2.1354 train_time:450485ms step_avg:43.53ms -step:10400/20000 train_loss:2.0705 train_time:452630ms step_avg:43.52ms -step:10400/20000 val_loss:2.1098 val_bpb:1.2463 train_time:452654ms step_avg:43.52ms -step:10450/20000 train_loss:2.0663 train_time:454900ms step_avg:43.53ms -step:10500/20000 train_loss:2.1334 train_time:457046ms step_avg:43.53ms -step:10550/20000 train_loss:2.1931 train_time:459192ms step_avg:43.53ms -step:10600/20000 train_loss:2.0978 train_time:461337ms step_avg:43.52ms -step:10600/20000 val_loss:2.1081 val_bpb:1.2453 train_time:461361ms step_avg:43.52ms -step:10650/20000 train_loss:2.0676 train_time:463610ms step_avg:43.53ms -step:10700/20000 train_loss:2.2333 train_time:465754ms step_avg:43.53ms -step:10750/20000 train_loss:2.1661 train_time:467899ms step_avg:43.53ms -step:10800/20000 train_loss:2.0966 train_time:470044ms step_avg:43.52ms -step:10800/20000 val_loss:2.1081 val_bpb:1.2453 train_time:470069ms step_avg:43.52ms -step:10850/20000 train_loss:2.0708 train_time:472323ms step_avg:43.53ms -step:10900/20000 train_loss:2.1666 train_time:474468ms step_avg:43.53ms -step:10950/20000 train_loss:2.1079 train_time:476615ms step_avg:43.53ms -step:11000/20000 train_loss:2.0774 train_time:478893ms step_avg:43.54ms -step:11000/20000 val_loss:2.1069 val_bpb:1.2446 train_time:478917ms step_avg:43.54ms -step:11050/20000 train_loss:2.1288 train_time:481038ms step_avg:43.53ms -step:11100/20000 train_loss:2.0801 train_time:483185ms step_avg:43.53ms -step:11150/20000 train_loss:1.8743 train_time:485331ms step_avg:43.53ms -step:11200/20000 train_loss:2.1471 train_time:487603ms step_avg:43.54ms -step:11200/20000 val_loss:2.1080 val_bpb:1.2452 train_time:487627ms step_avg:43.54ms -step:11250/20000 train_loss:2.2046 train_time:489748ms step_avg:43.53ms -step:11300/20000 train_loss:2.0957 train_time:491892ms step_avg:43.53ms -step:11350/20000 train_loss:2.0963 train_time:494038ms step_avg:43.53ms -step:11400/20000 train_loss:2.3223 train_time:496318ms step_avg:43.54ms -step:11400/20000 val_loss:2.1051 val_bpb:1.2435 train_time:496342ms step_avg:43.54ms -step:11450/20000 train_loss:2.0724 train_time:498464ms step_avg:43.53ms -step:11500/20000 train_loss:2.1197 train_time:500609ms step_avg:43.53ms -step:11550/20000 train_loss:2.0975 train_time:502754ms step_avg:43.53ms -step:11600/20000 train_loss:2.1091 train_time:505029ms step_avg:43.54ms -step:11600/20000 val_loss:2.1054 val_bpb:1.2437 train_time:505053ms step_avg:43.54ms -step:11650/20000 train_loss:2.1235 train_time:507175ms step_avg:43.53ms -step:11700/20000 train_loss:2.0795 train_time:509324ms step_avg:43.53ms -step:11750/20000 train_loss:2.0662 train_time:511469ms step_avg:43.53ms -step:11800/20000 train_loss:2.0765 train_time:513742ms step_avg:43.54ms -step:11800/20000 val_loss:2.1048 val_bpb:1.2433 train_time:513766ms step_avg:43.54ms -step:11850/20000 train_loss:2.1202 train_time:515888ms step_avg:43.53ms -step:11900/20000 train_loss:2.1029 train_time:518033ms step_avg:43.53ms -step:11950/20000 train_loss:2.1512 train_time:520308ms step_avg:43.54ms -step:12000/20000 train_loss:2.1814 train_time:522453ms step_avg:43.54ms -step:12000/20000 val_loss:2.1029 val_bpb:1.2422 train_time:522477ms step_avg:43.54ms -step:12050/20000 train_loss:2.1085 train_time:524601ms step_avg:43.54ms -step:12100/20000 train_loss:2.0347 train_time:526747ms step_avg:43.53ms -step:12150/20000 train_loss:2.0601 train_time:529018ms step_avg:43.54ms -step:12200/20000 train_loss:2.0387 train_time:531162ms step_avg:43.54ms -step:12200/20000 val_loss:2.1021 val_bpb:1.2418 train_time:531186ms step_avg:43.54ms -step:12250/20000 train_loss:2.0381 train_time:533312ms step_avg:43.54ms -step:12300/20000 train_loss:2.1302 train_time:535458ms step_avg:43.53ms -step:12350/20000 train_loss:2.1272 train_time:537727ms step_avg:43.54ms -step:12400/20000 train_loss:2.1828 train_time:539873ms step_avg:43.54ms -step:12400/20000 val_loss:2.1001 val_bpb:1.2406 train_time:539897ms step_avg:43.54ms -step:12450/20000 train_loss:2.1003 train_time:542019ms step_avg:43.54ms -step:12500/20000 train_loss:2.0696 train_time:544164ms step_avg:43.53ms -step:12550/20000 train_loss:2.1302 train_time:546436ms step_avg:43.54ms -step:12600/20000 train_loss:2.0527 train_time:548582ms step_avg:43.54ms -step:12600/20000 val_loss:2.0998 val_bpb:1.2404 train_time:548606ms step_avg:43.54ms -step:12650/20000 train_loss:2.1438 train_time:550728ms step_avg:43.54ms -step:12700/20000 train_loss:2.2689 train_time:552877ms step_avg:43.53ms -step:12750/20000 train_loss:2.1438 train_time:555147ms step_avg:43.54ms -step:12800/20000 train_loss:2.0105 train_time:557293ms step_avg:43.54ms -step:12800/20000 val_loss:2.0930 val_bpb:1.2364 train_time:557317ms step_avg:43.54ms -step:12850/20000 train_loss:2.0413 train_time:559440ms step_avg:43.54ms -step:12900/20000 train_loss:2.0630 train_time:561586ms step_avg:43.53ms -step:12950/20000 train_loss:2.1627 train_time:563863ms step_avg:43.54ms -step:13000/20000 train_loss:1.9579 train_time:566009ms step_avg:43.54ms -step:13000/20000 val_loss:2.0859 val_bpb:1.2322 train_time:566032ms step_avg:43.54ms -step:13050/20000 train_loss:2.0206 train_time:568155ms step_avg:43.54ms -step:13100/20000 train_loss:1.9294 train_time:570432ms step_avg:43.54ms -step:13150/20000 train_loss:2.0689 train_time:572576ms step_avg:43.54ms -step:13200/20000 train_loss:2.0074 train_time:574722ms step_avg:43.54ms -step:13200/20000 val_loss:2.0790 val_bpb:1.2281 train_time:574747ms step_avg:43.54ms -step:13250/20000 train_loss:2.0596 train_time:576871ms step_avg:43.54ms -step:13300/20000 train_loss:1.9474 train_time:579143ms step_avg:43.54ms -step:13350/20000 train_loss:2.0459 train_time:581289ms step_avg:43.54ms -step:13400/20000 train_loss:2.0441 train_time:583434ms step_avg:43.54ms -step:13400/20000 val_loss:2.0718 val_bpb:1.2239 train_time:583458ms step_avg:43.54ms -step:13450/20000 train_loss:2.1638 train_time:585582ms step_avg:43.54ms -step:13500/20000 train_loss:2.1216 train_time:587857ms step_avg:43.54ms -step:13550/20000 train_loss:2.1855 train_time:590003ms step_avg:43.54ms -step:13600/20000 train_loss:2.0234 train_time:592147ms step_avg:43.54ms -step:13600/20000 val_loss:2.0649 val_bpb:1.2197 train_time:592172ms step_avg:43.54ms -step:13650/20000 train_loss:2.0316 train_time:594295ms step_avg:43.54ms -step:13700/20000 train_loss:2.0323 train_time:596577ms step_avg:43.55ms -step:13750/20000 train_loss:1.9910 train_time:598726ms step_avg:43.54ms -step:13780/20000 val_loss:2.0606 val_bpb:1.2172 train_time:600038ms step_avg:43.54ms -stopping_early: wallclock_cap train_time:600038ms step:13780/20000 -peak memory allocated: 10184 MiB reserved: 10200 MiB -Serialized model: 67224983 bytes -Code size: 47642 bytes -Total submission size: 67272625 bytes -Serialized model int8+zlib: 15815847 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15863489 bytes -final_int8_zlib_roundtrip val_loss:2.0727 val_bpb:1.2244 eval_time:1401ms -final_int8_zlib_roundtrip_exact val_loss:2.07269931 val_bpb:1.22436570 diff --git a/records/track_10min_16mb/2026-03-17_NaiveBaseline/train_gpt.py b/records/track_10min_16mb/2026-03-17_NaiveBaseline/train_gpt.py deleted file mode 100644 index 0deb0565f5..0000000000 --- a/records/track_10min_16mb/2026-03-17_NaiveBaseline/train_gpt.py +++ /dev/null @@ -1,1126 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/README.md b/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/README.md deleted file mode 100644 index 0e5004df77..0000000000 --- a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/README.md +++ /dev/null @@ -1,59 +0,0 @@ -Kept the tied embedding in fp16 instead of quantizing it to int8, and tuned the LR schedule. Turns out the embedding is by far the most sensitive tensor to quantize — it's pulling double duty as the output head, so every bit of precision matters. - -## what changed - -**fp16 embedding passthrough**: one-line change in the quantization function. Instead of int8-quantizing `tok_emb.weight`, I pass it through as fp16. This drops the post-quant BPB degradation from ~0.007 to basically nothing (~0.0005). The tradeoff is ~500KB extra in the artifact, so I shrank the MLP hidden from 1024 to 992 to stay under 16MB. - -**warmdown + LR**: bumped `WARMDOWN_ITERS` from 1200 to 3600 and `MATRIX_LR` from 0.04 to 0.06. The default schedule assumes way more steps than you actually get in 10 minutes, so a longer warmdown and higher LR help the model converge properly. - -## config - -``` -VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 -MLP_HIDDEN=992 TIE_EMBEDDINGS=1 WARMDOWN_ITERS=3600 MATRIX_LR=0.06 -``` - -## run command - -```bash -RUN_ID=fp16embed \ -DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -MLP_HIDDEN=992 \ -WARMDOWN_ITERS=3600 \ -MATRIX_LR=0.06 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -Note: don't set `NCCL_IB_DISABLE=1` — it tanks step throughput on pods with IB/NVLink (~60ms vs ~44ms per step). - -## results - -8xH100 SXM (RunPod secure cloud): - -| seed | steps | val_loss | val_bpb | artifact size | -|------|-------|----------|---------|---------------| -| 1337 | 13,692 | 2.0595 | 1.2197 | 15.90MB | -| 42 | 13,722 | 2.0600 | 1.2201 | 15.90MB | - -Pre-quant vs post-quant gap: ~0.0005 BPB (baseline gap is ~0.007). - -Improvement over baseline: ~0.013 nats. - -Also ran 3 seeds on 8xH200 SXM (all consistent, 1.2163-1.2179 BPB). - -## things I tried that didn't work - -- **SwiGLU**: better per-step quality but 45% slower on 8-GPU, so fewer total steps. Net negative. -- **depth recurrence** (looping layers): promising idea but needs way more steps than 10 min allows. -- **QAT**: tried both full-training and late-stage. The overhead per step wasn't worth the small quant gap reduction. -- **lzma compression**: actually compresses worse than zlib for int8 weight data. -- **higher embed LR** (0.08 vs 0.05): hurt convergence. - -## files - -- `train_gpt.py` — modified training script -- `train.log` — 8xH100 log (seed 1337) -- `train_seed42.log` — 8xH100 log (seed 42) -- `submission.json` diff --git a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/submission.json b/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/submission.json deleted file mode 100644 index 6ff07bab72..0000000000 --- a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "Renier Velazco", - "github_id": "chonchiog", - "name": "FP16 Tied Embedding + LR/Warmdown Tuning", - "blurb": "Keep tok_emb.weight in fp16 during int8 quantization to eliminate the output-head quantization gap (0.007 -> 0.0005 BPB). Slightly reduce MLP hidden (992 vs 1024) to fit within 16MB. Tune warmdown (3600 vs 1200) and matrix LR (0.06 vs 0.04) for better convergence under the 10-min wallclock cap.", - "date": "2026-03-18", - "val_loss": 2.05945460, - "val_bpb": 1.21972502, - "bytes_total": 15896222, - "bytes_code": 48125 -} diff --git a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train.log b/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train.log deleted file mode 100644 index b4cf10e39f..0000000000 --- a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train.log +++ /dev/null @@ -1,1263 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) # override mlp_mult if > 0 - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - # Also keep the tied embedding in fp16 to minimize quantization degradation - # on the output head, which is the most sensitive tensor for BPB. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL or name == "tok_emb.weight": - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): - super().__init__() - hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - mlp_hidden: int = 0, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mlp_hidden: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - mlp_hidden=mlp_hidden, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mlp_hidden=args.mlp_hidden, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Thu Mar 19 00:32:54 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 33C P0 147W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 37C P0 152W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 37C P0 153W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 34C P0 151W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 35C P0 151W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 38C P0 150W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 36C P0 150W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 34C P0 150W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:16765000 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.06 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/20000 train_loss:6.9362 train_time:68ms step_avg:67.67ms -step:2/20000 train_loss:16.8730 train_time:108ms step_avg:54.09ms -step:3/20000 train_loss:8.5225 train_time:160ms step_avg:53.29ms -step:4/20000 train_loss:6.6152 train_time:206ms step_avg:51.59ms -step:5/20000 train_loss:6.8377 train_time:261ms step_avg:52.14ms -step:6/20000 train_loss:7.5396 train_time:310ms step_avg:51.68ms -step:7/20000 train_loss:6.4540 train_time:365ms step_avg:52.08ms -step:8/20000 train_loss:6.1992 train_time:416ms step_avg:52.03ms -step:9/20000 train_loss:6.0263 train_time:472ms step_avg:52.42ms -step:10/20000 train_loss:5.9363 train_time:527ms step_avg:52.68ms -step:500/20000 train_loss:2.5003 train_time:21847ms step_avg:43.69ms -step:1000/20000 train_loss:2.3731 train_time:43708ms step_avg:43.71ms -step:1500/20000 train_loss:2.2249 train_time:65435ms step_avg:43.62ms -step:2000/20000 train_loss:2.2537 train_time:87313ms step_avg:43.66ms -step:2500/20000 train_loss:2.1287 train_time:109175ms step_avg:43.67ms -step:3000/20000 train_loss:2.2325 train_time:130926ms step_avg:43.64ms -step:3500/20000 train_loss:2.2540 train_time:152759ms step_avg:43.65ms -step:4000/20000 train_loss:2.1636 train_time:174500ms step_avg:43.62ms -step:4500/20000 train_loss:2.2599 train_time:196331ms step_avg:43.63ms -step:5000/20000 train_loss:2.0318 train_time:218162ms step_avg:43.63ms -step:5500/20000 train_loss:2.1445 train_time:239884ms step_avg:43.62ms -step:6000/20000 train_loss:2.2222 train_time:261726ms step_avg:43.62ms -step:6500/20000 train_loss:2.0817 train_time:283581ms step_avg:43.63ms -step:7000/20000 train_loss:2.2303 train_time:305317ms step_avg:43.62ms -step:7500/20000 train_loss:2.1147 train_time:327164ms step_avg:43.62ms -step:8000/20000 train_loss:2.1190 train_time:348888ms step_avg:43.61ms -step:8500/20000 train_loss:2.1278 train_time:370717ms step_avg:43.61ms -step:9000/20000 train_loss:2.0327 train_time:392536ms step_avg:43.62ms -step:9500/20000 train_loss:2.2631 train_time:414362ms step_avg:43.62ms -step:10000/20000 train_loss:2.1185 train_time:436186ms step_avg:43.62ms -step:10500/20000 train_loss:1.9820 train_time:458034ms step_avg:43.62ms -step:11000/20000 train_loss:2.1087 train_time:479776ms step_avg:43.62ms -step:11500/20000 train_loss:2.0948 train_time:501635ms step_avg:43.62ms -step:12000/20000 train_loss:2.0525 train_time:523376ms step_avg:43.61ms -step:12500/20000 train_loss:2.3395 train_time:545215ms step_avg:43.62ms -step:13000/20000 train_loss:2.1638 train_time:567073ms step_avg:43.62ms -step:13500/20000 train_loss:2.0154 train_time:588838ms step_avg:43.62ms -step:13692/20000 val_loss:2.0586 val_bpb:1.2192 train_time:600008ms step_avg:43.82ms -stopping_early: wallclock_cap train_time:600008ms step:13692/20000 -peak memory allocated: 10111 MiB reserved: 10302 MiB -Serialized model: 66045335 bytes -Code size: 48125 bytes -Total submission size: 66093460 bytes -Serialized model int8+zlib: 15848097 bytes (payload:17405664 raw_torch:17450459 payload_ratio:3.79x) -Total submission size int8+zlib: 15896222 bytes -final_int8_zlib_roundtrip val_loss:2.0595 val_bpb:1.2197 eval_time:1399ms -final_int8_zlib_roundtrip_exact val_loss:2.05945460 val_bpb:1.21972502 diff --git a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train_gpt.py b/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train_gpt.py deleted file mode 100644 index fd6c415533..0000000000 --- a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train_gpt.py +++ /dev/null @@ -1,1133 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) # override mlp_mult if > 0 - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - # Also keep the tied embedding in fp16 to minimize quantization degradation - # on the output head, which is the most sensitive tensor for BPB. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL or name == "tok_emb.weight": - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): - super().__init__() - hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - mlp_hidden: int = 0, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mlp_hidden: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - mlp_hidden=mlp_hidden, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mlp_hidden=args.mlp_hidden, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train_seed42.log b/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train_seed42.log deleted file mode 100644 index 6b74aefaef..0000000000 --- a/records/track_10min_16mb/2026-03-18_FP16Embed_WD3600/train_seed42.log +++ /dev/null @@ -1,1311 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) # override mlp_mult if > 0 - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - # Also keep the tied embedding in fp16 to minimize quantization degradation - # on the output head, which is the most sensitive tensor for BPB. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL or name == "tok_emb.weight": - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): - super().__init__() - hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - mlp_hidden: int = 0, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mlp_hidden: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - mlp_hidden=mlp_hidden, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mlp_hidden=args.mlp_hidden, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Thu Mar 19 00:57:34 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 33C P0 147W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 36C P0 151W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 35C P0 149W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 33C P0 152W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 35C P0 151W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 37C P0 148W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 35C P0 149W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 33C P0 147W / 700W | 1518MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:16765000 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.06 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9386 val_bpb:4.1094 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9374 train_time:32ms step_avg:31.73ms -step:2/20000 train_loss:16.7845 train_time:82ms step_avg:41.03ms -step:3/20000 train_loss:8.3491 train_time:133ms step_avg:44.45ms -step:4/20000 train_loss:6.7040 train_time:185ms step_avg:46.20ms -step:5/20000 train_loss:6.8139 train_time:235ms step_avg:47.07ms -step:6/20000 train_loss:7.4557 train_time:291ms step_avg:48.43ms -step:7/20000 train_loss:6.4028 train_time:343ms step_avg:49.03ms -step:8/20000 train_loss:6.1309 train_time:398ms step_avg:49.75ms -step:9/20000 train_loss:6.0002 train_time:455ms step_avg:50.51ms -step:10/20000 train_loss:5.8936 train_time:513ms step_avg:51.26ms -step:200/20000 train_loss:2.8638 train_time:8830ms step_avg:44.15ms -step:400/20000 train_loss:2.3624 train_time:17564ms step_avg:43.91ms -step:600/20000 train_loss:2.5469 train_time:26310ms step_avg:43.85ms -step:800/20000 train_loss:2.2999 train_time:35076ms step_avg:43.85ms -step:1000/20000 train_loss:2.3767 train_time:43826ms step_avg:43.83ms -step:1200/20000 train_loss:2.3883 train_time:52566ms step_avg:43.81ms -step:1400/20000 train_loss:2.4343 train_time:61328ms step_avg:43.81ms -step:1600/20000 train_loss:2.1019 train_time:70068ms step_avg:43.79ms -step:1800/20000 train_loss:2.2065 train_time:78798ms step_avg:43.78ms -step:2000/20000 train_loss:2.2557 train_time:87543ms step_avg:43.77ms -step:2000/20000 val_loss:2.2388 val_bpb:1.3259 train_time:87562ms step_avg:43.78ms -step:2200/20000 train_loss:2.0773 train_time:96276ms step_avg:43.76ms -step:2400/20000 train_loss:2.2039 train_time:105005ms step_avg:43.75ms -step:2600/20000 train_loss:2.4142 train_time:113743ms step_avg:43.75ms -step:2800/20000 train_loss:2.2384 train_time:122552ms step_avg:43.77ms -step:3000/20000 train_loss:2.2291 train_time:131285ms step_avg:43.76ms -step:3200/20000 train_loss:2.1880 train_time:140018ms step_avg:43.76ms -step:3400/20000 train_loss:2.1604 train_time:148771ms step_avg:43.76ms -step:3600/20000 train_loss:2.1172 train_time:157501ms step_avg:43.75ms -step:3800/20000 train_loss:2.2261 train_time:166232ms step_avg:43.75ms -step:4000/20000 train_loss:2.1641 train_time:174959ms step_avg:43.74ms -step:4000/20000 val_loss:2.1722 val_bpb:1.2865 train_time:174975ms step_avg:43.74ms -step:4200/20000 train_loss:2.1771 train_time:183777ms step_avg:43.76ms -step:4400/20000 train_loss:2.1164 train_time:192514ms step_avg:43.75ms -step:4600/20000 train_loss:1.9734 train_time:201249ms step_avg:43.75ms -step:4800/20000 train_loss:2.2695 train_time:209978ms step_avg:43.75ms -step:5000/20000 train_loss:2.0314 train_time:218701ms step_avg:43.74ms -step:5200/20000 train_loss:2.1773 train_time:227437ms step_avg:43.74ms -step:5400/20000 train_loss:2.1858 train_time:236176ms step_avg:43.74ms -step:5600/20000 train_loss:2.1886 train_time:244909ms step_avg:43.73ms -step:5800/20000 train_loss:2.1449 train_time:253638ms step_avg:43.73ms -step:6000/20000 train_loss:2.2212 train_time:262378ms step_avg:43.73ms -step:6000/20000 val_loss:2.1458 val_bpb:1.2709 train_time:262394ms step_avg:43.73ms -step:6200/20000 train_loss:2.0900 train_time:271098ms step_avg:43.73ms -step:6400/20000 train_loss:2.1637 train_time:279831ms step_avg:43.72ms -step:6600/20000 train_loss:2.1237 train_time:288565ms step_avg:43.72ms -step:6800/20000 train_loss:2.1928 train_time:297295ms step_avg:43.72ms -step:7000/20000 train_loss:2.2285 train_time:306032ms step_avg:43.72ms -step:7200/20000 train_loss:2.1982 train_time:314770ms step_avg:43.72ms -step:7400/20000 train_loss:2.1229 train_time:323525ms step_avg:43.72ms -step:7600/20000 train_loss:1.9980 train_time:332277ms step_avg:43.72ms -step:7800/20000 train_loss:2.1451 train_time:341024ms step_avg:43.72ms -step:8000/20000 train_loss:2.1144 train_time:349774ms step_avg:43.72ms -step:8000/20000 val_loss:2.1252 val_bpb:1.2587 train_time:349793ms step_avg:43.72ms -step:8200/20000 train_loss:2.1890 train_time:358502ms step_avg:43.72ms -step:8400/20000 train_loss:2.1378 train_time:367330ms step_avg:43.73ms -step:8600/20000 train_loss:2.1407 train_time:376094ms step_avg:43.73ms -step:8800/20000 train_loss:2.0997 train_time:384842ms step_avg:43.73ms -step:9000/20000 train_loss:2.0303 train_time:393596ms step_avg:43.73ms -step:9200/20000 train_loss:2.0892 train_time:402333ms step_avg:43.73ms -step:9400/20000 train_loss:2.1323 train_time:411086ms step_avg:43.73ms -step:9600/20000 train_loss:2.1527 train_time:419832ms step_avg:43.73ms -step:9800/20000 train_loss:2.0751 train_time:428571ms step_avg:43.73ms -step:10000/20000 train_loss:2.1173 train_time:437330ms step_avg:43.73ms -step:10000/20000 val_loss:2.1148 val_bpb:1.2525 train_time:437347ms step_avg:43.73ms -step:10200/20000 train_loss:2.0678 train_time:446065ms step_avg:43.73ms -step:10400/20000 train_loss:2.1034 train_time:454794ms step_avg:43.73ms -step:10600/20000 train_loss:1.9767 train_time:463528ms step_avg:43.73ms -step:10800/20000 train_loss:2.1827 train_time:472270ms step_avg:43.73ms -step:11000/20000 train_loss:2.1106 train_time:481007ms step_avg:43.73ms -step:11200/20000 train_loss:2.0630 train_time:489740ms step_avg:43.73ms -step:11400/20000 train_loss:2.0444 train_time:498486ms step_avg:43.73ms -step:11600/20000 train_loss:2.0506 train_time:507218ms step_avg:43.73ms -step:11800/20000 train_loss:2.0765 train_time:515969ms step_avg:43.73ms -step:12000/20000 train_loss:2.0517 train_time:524706ms step_avg:43.73ms -step:12000/20000 val_loss:2.0813 val_bpb:1.2326 train_time:524722ms step_avg:43.73ms -step:12200/20000 train_loss:2.1971 train_time:533452ms step_avg:43.73ms -step:12400/20000 train_loss:1.8441 train_time:542306ms step_avg:43.73ms -step:12600/20000 train_loss:2.0663 train_time:551034ms step_avg:43.73ms -step:12800/20000 train_loss:2.0870 train_time:559771ms step_avg:43.73ms -step:13000/20000 train_loss:2.1597 train_time:568510ms step_avg:43.73ms -step:13200/20000 train_loss:2.1700 train_time:577237ms step_avg:43.73ms -step:13400/20000 train_loss:2.0473 train_time:585968ms step_avg:43.73ms -step:13600/20000 train_loss:1.9135 train_time:594715ms step_avg:43.73ms -step:13722/20000 val_loss:2.0576 val_bpb:1.2186 train_time:599992ms step_avg:43.72ms -stopping_early: wallclock_cap train_time:599992ms step:13722/20000 -peak memory allocated: 10111 MiB reserved: 10302 MiB -Serialized model: 66045335 bytes -Code size: 48125 bytes -Total submission size: 66093460 bytes -Serialized model int8+zlib: 15849048 bytes (payload:17405664 raw_torch:17450459 payload_ratio:3.79x) -Total submission size int8+zlib: 15897173 bytes -final_int8_zlib_roundtrip val_loss:2.0600 val_bpb:1.2201 eval_time:1408ms -final_int8_zlib_roundtrip_exact val_loss:2.06002794 val_bpb:1.22006458 diff --git a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/README.md b/records/track_10min_16mb/2026-03-18_LongContextSeq2048/README.md deleted file mode 100644 index 735cc0ccc6..0000000000 --- a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/README.md +++ /dev/null @@ -1,65 +0,0 @@ -This record submission is called `Long Context Seq2048 v2`. - -Configuration: -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied output/input embeddings: `TIE_EMBEDDINGS=1` -- Sequence length: `TRAIN_SEQ_LEN=2048` -- Batching: `TRAIN_BATCH_TOKENS=524288` -- Learning rates: `TIED_EMBED_LR=0.04 MATRIX_LR=0.032 SCALAR_LR=0.032` - -Command: -```bash -NCCL_IB_DISABLE=1 \ -RUN_ID=seq2048_sxm28_full_20260319a \ -DATA_PATH=./data/datasets/fineweb10B_sp1024 \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -MAX_WALLCLOCK_SECONDS=600 \ -torchrun --standalone --nproc_per_node=8 \ - records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_gpt.py -``` - -Verification environment: -- `8x H100 80GB HBM3` -- all-to-all `NV18` topology -- `torch 2.8.0+cu128` - -Key metrics (from `train.log` in this folder, rerun on the target SXM-class box): -- Timed training stopped at `11564/20000` steps due to the wallclock cap. -- Pre-quant eval at stop: `val_loss:2.0269`, `val_bpb:1.2005` -- Post-quant roundtrip eval: `val_loss:2.0359`, `val_bpb:1.2058` -- Exact printed metric: `final_int8_zlib_roundtrip_exact val_bpb:1.20576485` -- Train time: `600038ms` (`step_avg:51.89ms`) -- Peak memory: `10247 MiB allocated`, `10488 MiB reserved` -- Serialized model int8+zlib: `15819554 bytes` -- Code size for this standalone record script: `47716 bytes` -- Total submission size int8+zlib: `15867270 bytes` - -Additional full-run reproducibility logs included in this folder: -- `train.log`: canonical SXM rerun, `SEED=1337`, `val_bpb=1.20576485` -- `train_seed1338.log`: SXM rerun, `SEED=1338`, `val_bpb=1.20617460` -- `train_seed1339.log`: SXM rerun, `SEED=1339`, `val_bpb=1.20715923` - -Record-track significance note: -- The public repo state for this submission has `Naive Baseline` at `1.2243657`. -- The challenge therefore requires beating `1.2193657` to claim a new record. -- All three included SXM full runs clear that threshold: - - `SEED=1337`: `1.20576485` - - `SEED=1338`: `1.20617460` - - `SEED=1339`: `1.20715923` -- Sample mean across the three runs: `1.20636623` -- Sample standard deviation: `0.00071667` -- One-sided one-sample t-test against `1.2193657`: `t=31.42` with `df=2`, which gives `p=0.00051` - -Why this folder is standalone: -- `train_gpt.py` compiles from inside this record folder and was used for the canonical rerun whose output is saved as `train.log`. -- No extra Python source files are required for the training path. -- The only inputs expected at runtime are the cached dataset and tokenizer paths described in the main repo README. - -Included files: -- `train_gpt.py` (standalone winning recipe with defaults baked in) -- `README.md` (this file) -- `submission.json` (leaderboard metadata) -- `train.log` (canonical full log from the standalone record script) -- `train_seed1338.log`, `train_seed1339.log` (extra full reruns for reproducibility) -- `logs/seq2048_sxm28_*` (raw per-run tee output and trainer text logs from the SXM verification box) diff --git a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/submission.json b/records/track_10min_16mb/2026-03-18_LongContextSeq2048/submission.json deleted file mode 100644 index 049a8b74ad..0000000000 --- a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "Spokane Way", - "github_id": "spokane-way", - "name": "Long Context Seq2048 v2", - "blurb": "SP-1024 9x512 KV4 run at TRAIN_SEQ_LEN=2048 with tuned seq2048 learning rates (0.040/0.032/0.032). This standalone record script reproduces the SXM-verified 10-minute artifact under the 16,000,000-byte cap.", - "date": "2026-03-19T04:50:00Z", - "val_loss": 2.03588345, - "val_bpb": 1.20576485, - "bytes_total": 15867270, - "bytes_code": 47716 -} diff --git a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train.log b/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train.log deleted file mode 100644 index 15a6f65b12..0000000000 --- a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train.log +++ /dev/null @@ -1,124 +0,0 @@ - -***************************************** -Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -***************************************** -logs/seq2048_sxm28_full_20260319a.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf-sxm28/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/root/parameter-golf-sxm28/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.04 head_lr:0.0 matrix_lr:0.032 scalar_lr:0.032 -train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9370 train_time:27ms step_avg:27.23ms -step:2/20000 train_loss:14.7712 train_time:74ms step_avg:36.88ms -step:3/20000 train_loss:8.1324 train_time:125ms step_avg:41.59ms -step:4/20000 train_loss:6.6083 train_time:176ms step_avg:44.01ms -step:5/20000 train_loss:6.9060 train_time:227ms step_avg:45.47ms -step:6/20000 train_loss:7.6667 train_time:279ms step_avg:46.44ms -step:7/20000 train_loss:6.6546 train_time:330ms step_avg:47.13ms -step:8/20000 train_loss:6.3864 train_time:381ms step_avg:47.66ms -step:9/20000 train_loss:6.2280 train_time:433ms step_avg:48.07ms -step:10/20000 train_loss:6.1411 train_time:484ms step_avg:48.42ms -step:200/20000 train_loss:2.7753 train_time:10282ms step_avg:51.41ms -step:400/20000 train_loss:2.2990 train_time:20615ms step_avg:51.54ms -step:600/20000 train_loss:2.5004 train_time:30958ms step_avg:51.60ms -step:800/20000 train_loss:2.2435 train_time:41311ms step_avg:51.64ms -step:1000/20000 train_loss:2.3383 train_time:51684ms step_avg:51.68ms -step:1000/20000 val_loss:2.2909 val_bpb:1.3568 train_time:51717ms step_avg:51.72ms -step:1200/20000 train_loss:2.3520 train_time:62063ms step_avg:51.72ms -step:1400/20000 train_loss:2.3778 train_time:72454ms step_avg:51.75ms -step:1600/20000 train_loss:2.0422 train_time:82841ms step_avg:51.78ms -step:1800/20000 train_loss:2.1630 train_time:93248ms step_avg:51.80ms -step:2000/20000 train_loss:2.2122 train_time:103654ms step_avg:51.83ms -step:2000/20000 val_loss:2.1924 val_bpb:1.2984 train_time:103687ms step_avg:51.84ms -step:2200/20000 train_loss:2.0339 train_time:114067ms step_avg:51.85ms -step:2400/20000 train_loss:2.1666 train_time:124488ms step_avg:51.87ms -step:2600/20000 train_loss:2.3803 train_time:134904ms step_avg:51.89ms -step:2800/20000 train_loss:2.1944 train_time:145315ms step_avg:51.90ms -step:3000/20000 train_loss:2.1889 train_time:155728ms step_avg:51.91ms -step:3000/20000 val_loss:2.1524 val_bpb:1.2748 train_time:155761ms step_avg:51.92ms -step:3200/20000 train_loss:2.1507 train_time:166139ms step_avg:51.92ms -step:3400/20000 train_loss:2.1186 train_time:176537ms step_avg:51.92ms -step:3600/20000 train_loss:2.0636 train_time:186950ms step_avg:51.93ms -step:3800/20000 train_loss:2.1715 train_time:197346ms step_avg:51.93ms -step:4000/20000 train_loss:2.1326 train_time:207738ms step_avg:51.93ms -step:4000/20000 val_loss:2.1285 val_bpb:1.2606 train_time:207770ms step_avg:51.94ms -step:4200/20000 train_loss:2.1300 train_time:218180ms step_avg:51.95ms -step:4400/20000 train_loss:2.0635 train_time:228563ms step_avg:51.95ms -step:4600/20000 train_loss:1.9340 train_time:238947ms step_avg:51.95ms -step:4800/20000 train_loss:2.2169 train_time:249326ms step_avg:51.94ms -step:5000/20000 train_loss:1.9728 train_time:259712ms step_avg:51.94ms -step:5000/20000 val_loss:2.1118 val_bpb:1.2507 train_time:259745ms step_avg:51.95ms -step:5200/20000 train_loss:2.1346 train_time:270102ms step_avg:51.94ms -step:5400/20000 train_loss:2.1480 train_time:280489ms step_avg:51.94ms -step:5600/20000 train_loss:2.1403 train_time:290858ms step_avg:51.94ms -step:5800/20000 train_loss:2.0939 train_time:301230ms step_avg:51.94ms -step:6000/20000 train_loss:2.1745 train_time:311608ms step_avg:51.93ms -step:6000/20000 val_loss:2.1015 val_bpb:1.2446 train_time:311642ms step_avg:51.94ms -step:6200/20000 train_loss:2.0438 train_time:321983ms step_avg:51.93ms -step:6400/20000 train_loss:2.1272 train_time:332352ms step_avg:51.93ms -step:6600/20000 train_loss:2.0825 train_time:342718ms step_avg:51.93ms -step:6800/20000 train_loss:2.1436 train_time:353087ms step_avg:51.92ms -step:7000/20000 train_loss:2.1914 train_time:363453ms step_avg:51.92ms -step:7000/20000 val_loss:2.0907 val_bpb:1.2382 train_time:363485ms step_avg:51.93ms -step:7200/20000 train_loss:2.1618 train_time:373813ms step_avg:51.92ms -step:7400/20000 train_loss:2.0806 train_time:384181ms step_avg:51.92ms -step:7600/20000 train_loss:1.9643 train_time:394550ms step_avg:51.91ms -step:7800/20000 train_loss:2.1069 train_time:404903ms step_avg:51.91ms -step:8000/20000 train_loss:2.0808 train_time:415270ms step_avg:51.91ms -step:8000/20000 val_loss:2.0816 val_bpb:1.2328 train_time:415302ms step_avg:51.91ms -step:8200/20000 train_loss:2.1517 train_time:425628ms step_avg:51.91ms -step:8400/20000 train_loss:2.0958 train_time:436033ms step_avg:51.91ms -step:8600/20000 train_loss:2.1052 train_time:446388ms step_avg:51.91ms -step:8800/20000 train_loss:2.0699 train_time:456752ms step_avg:51.90ms -step:9000/20000 train_loss:1.9858 train_time:467109ms step_avg:51.90ms -step:9000/20000 val_loss:2.0765 val_bpb:1.2298 train_time:467142ms step_avg:51.90ms -step:9200/20000 train_loss:2.0473 train_time:477468ms step_avg:51.90ms -step:9400/20000 train_loss:2.0934 train_time:487824ms step_avg:51.90ms -step:9600/20000 train_loss:2.1151 train_time:498188ms step_avg:51.89ms -step:9800/20000 train_loss:2.0174 train_time:508551ms step_avg:51.89ms -step:10000/20000 train_loss:2.0742 train_time:518903ms step_avg:51.89ms -step:10000/20000 val_loss:2.0715 val_bpb:1.2268 train_time:518936ms step_avg:51.89ms -step:10200/20000 train_loss:2.0357 train_time:529265ms step_avg:51.89ms -step:10400/20000 train_loss:2.0548 train_time:539622ms step_avg:51.89ms -step:10600/20000 train_loss:1.9345 train_time:549977ms step_avg:51.88ms -step:10800/20000 train_loss:2.1369 train_time:560331ms step_avg:51.88ms -step:11000/20000 train_loss:2.0578 train_time:570691ms step_avg:51.88ms -step:11000/20000 val_loss:2.0447 val_bpb:1.2110 train_time:570724ms step_avg:51.88ms -step:11200/20000 train_loss:2.0111 train_time:581136ms step_avg:51.89ms -step:11400/20000 train_loss:1.9882 train_time:591500ms step_avg:51.89ms -step:11564/20000 val_loss:2.0269 val_bpb:1.2005 train_time:600038ms step_avg:51.89ms -stopping_early: wallclock_cap train_time:600038ms step:11564/20000 -peak memory allocated: 10247 MiB reserved: 10488 MiB -Serialized model: 67224983 bytes -Code size: 47716 bytes -Total submission size: 67272699 bytes -Serialized model int8+zlib: 15819554 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15867270 bytes -final_int8_zlib_roundtrip val_loss:2.0359 val_bpb:1.2058 eval_time:1639ms -final_int8_zlib_roundtrip_exact val_loss:2.03588345 val_bpb:1.20576485 diff --git a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_gpt.py b/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_gpt.py deleted file mode 100644 index a9ceb044e5..0000000000 --- a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_gpt.py +++ /dev/null @@ -1,1127 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Long Context Seq2048 v2 run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 2048, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap -# - tuned seq2048 learning rates: 0.040 / 0.032 / 0.032 - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", "long_context_seq2048_v2") - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.04)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.032)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.032)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_seed1338.log b/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_seed1338.log deleted file mode 100644 index 5851fe0a3c..0000000000 --- a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_seed1338.log +++ /dev/null @@ -1,124 +0,0 @@ - -***************************************** -Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -***************************************** -logs/seq2048_sxm28_seed1338_20260319a.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf-sxm28/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/root/parameter-golf-sxm28/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.04 head_lr:0.0 matrix_lr:0.032 scalar_lr:0.032 -train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1338 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9373 val_bpb:4.1086 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9363 train_time:27ms step_avg:27.47ms -step:2/20000 train_loss:14.7813 train_time:73ms step_avg:36.26ms -step:3/20000 train_loss:8.1303 train_time:124ms step_avg:41.23ms -step:4/20000 train_loss:6.6459 train_time:175ms step_avg:43.65ms -step:5/20000 train_loss:6.9413 train_time:227ms step_avg:45.34ms -step:6/20000 train_loss:7.8068 train_time:278ms step_avg:46.34ms -step:7/20000 train_loss:6.8490 train_time:330ms step_avg:47.09ms -step:8/20000 train_loss:6.4809 train_time:381ms step_avg:47.61ms -step:9/20000 train_loss:6.2006 train_time:434ms step_avg:48.24ms -step:10/20000 train_loss:6.0502 train_time:483ms step_avg:48.35ms -step:200/20000 train_loss:2.7815 train_time:10297ms step_avg:51.48ms -step:400/20000 train_loss:2.3049 train_time:20646ms step_avg:51.62ms -step:600/20000 train_loss:2.4980 train_time:30990ms step_avg:51.65ms -step:800/20000 train_loss:2.2491 train_time:41345ms step_avg:51.68ms -step:1000/20000 train_loss:2.3361 train_time:51721ms step_avg:51.72ms -step:1000/20000 val_loss:2.2898 val_bpb:1.3562 train_time:51754ms step_avg:51.75ms -step:1200/20000 train_loss:2.3575 train_time:62108ms step_avg:51.76ms -step:1400/20000 train_loss:2.3813 train_time:72504ms step_avg:51.79ms -step:1600/20000 train_loss:2.0496 train_time:82894ms step_avg:51.81ms -step:1800/20000 train_loss:2.1676 train_time:93289ms step_avg:51.83ms -step:2000/20000 train_loss:2.2119 train_time:103684ms step_avg:51.84ms -step:2000/20000 val_loss:2.1937 val_bpb:1.2993 train_time:103717ms step_avg:51.86ms -step:2200/20000 train_loss:2.0320 train_time:114081ms step_avg:51.86ms -step:2400/20000 train_loss:2.1652 train_time:124471ms step_avg:51.86ms -step:2600/20000 train_loss:2.3821 train_time:134872ms step_avg:51.87ms -step:2800/20000 train_loss:2.1987 train_time:145264ms step_avg:51.88ms -step:3000/20000 train_loss:2.1912 train_time:155647ms step_avg:51.88ms -step:3000/20000 val_loss:2.1534 val_bpb:1.2754 train_time:155680ms step_avg:51.89ms -step:3200/20000 train_loss:2.1516 train_time:166037ms step_avg:51.89ms -step:3400/20000 train_loss:2.1199 train_time:176417ms step_avg:51.89ms -step:3600/20000 train_loss:2.0651 train_time:186795ms step_avg:51.89ms -step:3800/20000 train_loss:2.1694 train_time:197173ms step_avg:51.89ms -step:4000/20000 train_loss:2.1330 train_time:207554ms step_avg:51.89ms -step:4000/20000 val_loss:2.1292 val_bpb:1.2610 train_time:207587ms step_avg:51.90ms -step:4200/20000 train_loss:2.1284 train_time:217976ms step_avg:51.90ms -step:4400/20000 train_loss:2.0686 train_time:228351ms step_avg:51.90ms -step:4600/20000 train_loss:1.9371 train_time:238738ms step_avg:51.90ms -step:4800/20000 train_loss:2.2171 train_time:249109ms step_avg:51.90ms -step:5000/20000 train_loss:1.9744 train_time:259476ms step_avg:51.90ms -step:5000/20000 val_loss:2.1127 val_bpb:1.2512 train_time:259509ms step_avg:51.90ms -step:5200/20000 train_loss:2.1356 train_time:269848ms step_avg:51.89ms -step:5400/20000 train_loss:2.1527 train_time:280217ms step_avg:51.89ms -step:5600/20000 train_loss:2.1390 train_time:290578ms step_avg:51.89ms -step:5800/20000 train_loss:2.0944 train_time:300933ms step_avg:51.89ms -step:6000/20000 train_loss:2.1752 train_time:311294ms step_avg:51.88ms -step:6000/20000 val_loss:2.1026 val_bpb:1.2453 train_time:311327ms step_avg:51.89ms -step:6200/20000 train_loss:2.0458 train_time:321653ms step_avg:51.88ms -step:6400/20000 train_loss:2.1240 train_time:332019ms step_avg:51.88ms -step:6600/20000 train_loss:2.0830 train_time:342381ms step_avg:51.88ms -step:6800/20000 train_loss:2.1434 train_time:352738ms step_avg:51.87ms -step:7000/20000 train_loss:2.1907 train_time:363096ms step_avg:51.87ms -step:7000/20000 val_loss:2.0916 val_bpb:1.2388 train_time:363129ms step_avg:51.88ms -step:7200/20000 train_loss:2.1672 train_time:373450ms step_avg:51.87ms -step:7400/20000 train_loss:2.0842 train_time:383806ms step_avg:51.87ms -step:7600/20000 train_loss:1.9615 train_time:394163ms step_avg:51.86ms -step:7800/20000 train_loss:2.1113 train_time:404518ms step_avg:51.86ms -step:8000/20000 train_loss:2.0788 train_time:414870ms step_avg:51.86ms -step:8000/20000 val_loss:2.0826 val_bpb:1.2334 train_time:414903ms step_avg:51.86ms -step:8200/20000 train_loss:2.1505 train_time:425230ms step_avg:51.86ms -step:8400/20000 train_loss:2.0933 train_time:435626ms step_avg:51.86ms -step:8600/20000 train_loss:2.1070 train_time:445977ms step_avg:51.86ms -step:8800/20000 train_loss:2.0708 train_time:456329ms step_avg:51.86ms -step:9000/20000 train_loss:1.9882 train_time:466685ms step_avg:51.85ms -step:9000/20000 val_loss:2.0772 val_bpb:1.2302 train_time:466718ms step_avg:51.86ms -step:9200/20000 train_loss:2.0470 train_time:477091ms step_avg:51.86ms -step:9400/20000 train_loss:2.0941 train_time:487469ms step_avg:51.86ms -step:9600/20000 train_loss:2.1116 train_time:497817ms step_avg:51.86ms -step:9800/20000 train_loss:2.0202 train_time:508164ms step_avg:51.85ms -step:10000/20000 train_loss:2.0783 train_time:518510ms step_avg:51.85ms -step:10000/20000 val_loss:2.0723 val_bpb:1.2274 train_time:518543ms step_avg:51.85ms -step:10200/20000 train_loss:2.0342 train_time:528861ms step_avg:51.85ms -step:10400/20000 train_loss:2.0587 train_time:539208ms step_avg:51.85ms -step:10600/20000 train_loss:1.9323 train_time:549552ms step_avg:51.84ms -step:10800/20000 train_loss:2.1371 train_time:559906ms step_avg:51.84ms -step:11000/20000 train_loss:2.0567 train_time:570254ms step_avg:51.84ms -step:11000/20000 val_loss:2.0458 val_bpb:1.2116 train_time:570286ms step_avg:51.84ms -step:11200/20000 train_loss:2.0119 train_time:580608ms step_avg:51.84ms -step:11400/20000 train_loss:1.9920 train_time:590954ms step_avg:51.84ms -step:11575/20000 val_loss:2.0274 val_bpb:1.2007 train_time:600043ms step_avg:51.84ms -stopping_early: wallclock_cap train_time:600043ms step:11575/20000 -peak memory allocated: 10247 MiB reserved: 10312 MiB -Serialized model: 67224983 bytes -Code size: 47716 bytes -Total submission size: 67272699 bytes -Serialized model int8+zlib: 15813523 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15861239 bytes -final_int8_zlib_roundtrip val_loss:2.0366 val_bpb:1.2062 eval_time:1638ms -final_int8_zlib_roundtrip_exact val_loss:2.03657529 val_bpb:1.20617460 diff --git a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_seed1339.log b/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_seed1339.log deleted file mode 100644 index 527d535a46..0000000000 --- a/records/track_10min_16mb/2026-03-18_LongContextSeq2048/train_seed1339.log +++ /dev/null @@ -1,124 +0,0 @@ - -***************************************** -Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -***************************************** -logs/seq2048_sxm28_seed1339_20260319a.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf-sxm28/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/root/parameter-golf-sxm28/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.04 head_lr:0.0 matrix_lr:0.032 scalar_lr:0.032 -train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1339 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9372 val_bpb:4.1086 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9375 train_time:27ms step_avg:27.41ms -step:2/20000 train_loss:14.8195 train_time:75ms step_avg:37.30ms -step:3/20000 train_loss:8.0864 train_time:126ms step_avg:41.96ms -step:4/20000 train_loss:6.5534 train_time:177ms step_avg:44.18ms -step:5/20000 train_loss:6.8329 train_time:228ms step_avg:45.62ms -step:6/20000 train_loss:7.6678 train_time:279ms step_avg:46.55ms -step:7/20000 train_loss:6.6834 train_time:330ms step_avg:47.21ms -step:8/20000 train_loss:6.3318 train_time:382ms step_avg:47.72ms -step:9/20000 train_loss:6.1684 train_time:433ms step_avg:48.12ms -step:10/20000 train_loss:6.0971 train_time:484ms step_avg:48.42ms -step:200/20000 train_loss:2.7828 train_time:10300ms step_avg:51.50ms -step:400/20000 train_loss:2.2900 train_time:20640ms step_avg:51.60ms -step:600/20000 train_loss:2.5016 train_time:30985ms step_avg:51.64ms -step:800/20000 train_loss:2.2388 train_time:41342ms step_avg:51.68ms -step:1000/20000 train_loss:2.3326 train_time:51710ms step_avg:51.71ms -step:1000/20000 val_loss:2.2851 val_bpb:1.3534 train_time:51743ms step_avg:51.74ms -step:1200/20000 train_loss:2.3498 train_time:62082ms step_avg:51.74ms -step:1400/20000 train_loss:2.3834 train_time:72461ms step_avg:51.76ms -step:1600/20000 train_loss:2.0438 train_time:82831ms step_avg:51.77ms -step:1800/20000 train_loss:2.1610 train_time:93204ms step_avg:51.78ms -step:2000/20000 train_loss:2.2097 train_time:103580ms step_avg:51.79ms -step:2000/20000 val_loss:2.1921 val_bpb:1.2983 train_time:103613ms step_avg:51.81ms -step:2200/20000 train_loss:2.0283 train_time:113950ms step_avg:51.80ms -step:2400/20000 train_loss:2.1558 train_time:124322ms step_avg:51.80ms -step:2600/20000 train_loss:2.3815 train_time:134697ms step_avg:51.81ms -step:2800/20000 train_loss:2.1964 train_time:145067ms step_avg:51.81ms -step:3000/20000 train_loss:2.1879 train_time:155433ms step_avg:51.81ms -step:3000/20000 val_loss:2.1535 val_bpb:1.2754 train_time:155467ms step_avg:51.82ms -step:3200/20000 train_loss:2.1499 train_time:165801ms step_avg:51.81ms -step:3400/20000 train_loss:2.1196 train_time:176179ms step_avg:51.82ms -step:3600/20000 train_loss:2.0663 train_time:186549ms step_avg:51.82ms -step:3800/20000 train_loss:2.1720 train_time:196916ms step_avg:51.82ms -step:4000/20000 train_loss:2.1340 train_time:207282ms step_avg:51.82ms -step:4000/20000 val_loss:2.1294 val_bpb:1.2611 train_time:207315ms step_avg:51.83ms -step:4200/20000 train_loss:2.1284 train_time:217691ms step_avg:51.83ms -step:4400/20000 train_loss:2.0666 train_time:228047ms step_avg:51.83ms -step:4600/20000 train_loss:1.9387 train_time:238412ms step_avg:51.83ms -step:4800/20000 train_loss:2.2201 train_time:248762ms step_avg:51.83ms -step:5000/20000 train_loss:1.9748 train_time:259127ms step_avg:51.83ms -step:5000/20000 val_loss:2.1132 val_bpb:1.2516 train_time:259160ms step_avg:51.83ms -step:5200/20000 train_loss:2.1342 train_time:269493ms step_avg:51.83ms -step:5400/20000 train_loss:2.1527 train_time:279859ms step_avg:51.83ms -step:5600/20000 train_loss:2.1413 train_time:290221ms step_avg:51.83ms -step:5800/20000 train_loss:2.0992 train_time:300579ms step_avg:51.82ms -step:6000/20000 train_loss:2.1790 train_time:310939ms step_avg:51.82ms -step:6000/20000 val_loss:2.1039 val_bpb:1.2460 train_time:310972ms step_avg:51.83ms -step:6200/20000 train_loss:2.0485 train_time:321291ms step_avg:51.82ms -step:6400/20000 train_loss:2.1251 train_time:331655ms step_avg:51.82ms -step:6600/20000 train_loss:2.0805 train_time:342011ms step_avg:51.82ms -step:6800/20000 train_loss:2.1480 train_time:352372ms step_avg:51.82ms -step:7000/20000 train_loss:2.1942 train_time:362734ms step_avg:51.82ms -step:7000/20000 val_loss:2.0925 val_bpb:1.2393 train_time:362767ms step_avg:51.82ms -step:7200/20000 train_loss:2.1633 train_time:373093ms step_avg:51.82ms -step:7400/20000 train_loss:2.0834 train_time:383453ms step_avg:51.82ms -step:7600/20000 train_loss:1.9632 train_time:393812ms step_avg:51.82ms -step:7800/20000 train_loss:2.1106 train_time:404251ms step_avg:51.83ms -step:8000/20000 train_loss:2.0791 train_time:414608ms step_avg:51.83ms -step:8000/20000 val_loss:2.0836 val_bpb:1.2340 train_time:414641ms step_avg:51.83ms -step:8200/20000 train_loss:2.1540 train_time:424958ms step_avg:51.82ms -step:8400/20000 train_loss:2.0970 train_time:435353ms step_avg:51.83ms -step:8600/20000 train_loss:2.1104 train_time:445701ms step_avg:51.83ms -step:8800/20000 train_loss:2.0694 train_time:456052ms step_avg:51.82ms -step:9000/20000 train_loss:1.9874 train_time:466402ms step_avg:51.82ms -step:9000/20000 val_loss:2.0785 val_bpb:1.2310 train_time:466435ms step_avg:51.83ms -step:9200/20000 train_loss:2.0479 train_time:476785ms step_avg:51.82ms -step:9400/20000 train_loss:2.0944 train_time:487193ms step_avg:51.83ms -step:9600/20000 train_loss:2.1114 train_time:497576ms step_avg:51.83ms -step:9800/20000 train_loss:2.0204 train_time:507942ms step_avg:51.83ms -step:10000/20000 train_loss:2.0788 train_time:518312ms step_avg:51.83ms -step:10000/20000 val_loss:2.0729 val_bpb:1.2277 train_time:518345ms step_avg:51.83ms -step:10200/20000 train_loss:2.0372 train_time:528681ms step_avg:51.83ms -step:10400/20000 train_loss:2.0577 train_time:539056ms step_avg:51.83ms -step:10600/20000 train_loss:1.9348 train_time:549410ms step_avg:51.83ms -step:10800/20000 train_loss:2.1410 train_time:559767ms step_avg:51.83ms -step:11000/20000 train_loss:2.0562 train_time:570117ms step_avg:51.83ms -step:11000/20000 val_loss:2.0472 val_bpb:1.2125 train_time:570149ms step_avg:51.83ms -step:11200/20000 train_loss:2.0166 train_time:580461ms step_avg:51.83ms -step:11400/20000 train_loss:1.9939 train_time:590808ms step_avg:51.83ms -step:11578/20000 val_loss:2.0286 val_bpb:1.2015 train_time:600051ms step_avg:51.83ms -stopping_early: wallclock_cap train_time:600051ms step:11578/20000 -peak memory allocated: 10247 MiB reserved: 10312 MiB -Serialized model: 67224983 bytes -Code size: 47716 bytes -Total submission size: 67272699 bytes -Serialized model int8+zlib: 15814036 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15861752 bytes -final_int8_zlib_roundtrip val_loss:2.0382 val_bpb:1.2072 eval_time:1640ms -final_int8_zlib_roundtrip_exact val_loss:2.03823779 val_bpb:1.20715923 diff --git a/records/track_10min_16mb/2026-03-18_LowerLR/README.md b/records/track_10min_16mb/2026-03-18_LowerLR/README.md deleted file mode 100644 index de55954ab1..0000000000 --- a/records/track_10min_16mb/2026-03-18_LowerLR/README.md +++ /dev/null @@ -1,66 +0,0 @@ -This record captures the `Lower LR` submission. - -## Summary - -Same baseline architecture (9x512, SP-1024, 4 KV heads, tied embeddings, relu^2 MLP) with lower learning rates. A systematic LR sweep over 8 experiments showed the default Muon/Adam learning rates (MATRIX_LR=0.04, SCALAR_LR=0.04, TIED_EMBED_LR=0.05) were too high. Optimal is approximately half the default. - -## Changes from baseline -- `MATRIX_LR=0.02` (default: 0.04) -- `SCALAR_LR=0.02` (default: 0.04) -- `TIED_EMBED_LR=0.03` (default: 0.05) - -No architecture, schedule, or other hyperparameter changes. - -## LR sweep results (8-GPU H200, 600s) - -| MATRIX_LR | val_bpb (post-quant) | Delta vs baseline | -|---|---|---| -| 0.06 | 1.2445 | +0.0159 (much worse) | -| 0.04 (default) | 1.2286 | — | -| 0.03 | 1.2279 | -0.0007 | -| 0.025 | 1.2250 | -0.0036 | -| **0.02** | **1.2230** | **-0.0056** | -| 0.015 | 1.2234 | -0.0052 | - -Configuration: -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied output/input embeddings: `TIE_EMBEDDINGS=1` -- Learning rates: `MATRIX_LR=0.02 SCALAR_LR=0.02 TIED_EMBED_LR=0.03` -- Batching: `TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=1024` - -Command: -```bash -RUN_ID=exp25_lr_0.02 \ -DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -MAX_WALLCLOCK_SECONDS=600 \ -VAL_LOSS_EVERY=200 \ -TRAIN_LOG_EVERY=50 \ -MATRIX_LR=0.02 \ -SCALAR_LR=0.02 \ -TIED_EMBED_LR=0.03 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -Key metrics (from `train.log`): -- Timed training stopped at `14421/20000` steps due to the wallclock cap. -- Pre-quant eval at stop: `val_loss:2.0571`, `val_bpb:1.2183` -- Post-quant roundtrip eval: `val_loss:2.0649`, `val_bpb:1.2230` -- Exact printed metric: `final_int8_zlib_roundtrip_exact val_bpb:1.22296644` -- Train time: `599847ms` (`step_avg:41.60ms`) -- Peak memory: `10246 MiB allocated`, `10310 MiB reserved` -- Serialized model int8+zlib: `15803327 bytes` -- Code size: `50919 bytes` -- Total submission size int8+zlib: `15854246 bytes` - -Training volume: -- Global batch: `524288` tokens/step -- Total train tokens seen: `7560609792` - -Note: Run performed on 8xH200 (141GB HBM3e). Step time (41.60ms) is comparable to 8xH100 baseline (43.54ms), and memory usage (10.2 GiB) is well within H100's 80GB limit. The ~5% faster step time on H200 yields ~400 extra steps, which may account for a small portion of the improvement. - -Included files: -- `train_gpt.py` (code snapshot used for the run) -- `train.log` (exact training log) -- `submission.json` (leaderboard metadata) diff --git a/records/track_10min_16mb/2026-03-18_LowerLR/submission.json b/records/track_10min_16mb/2026-03-18_LowerLR/submission.json deleted file mode 100644 index 42ec3279b3..0000000000 --- a/records/track_10min_16mb/2026-03-18_LowerLR/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "Nan Liu", - "github_id": "nanlliu", - "name": "Lower LR", - "blurb": "Same 9x512 SP-1024 KV4 tied-embedding baseline architecture with lower Muon/Adam learning rates (MATRIX_LR=0.02, SCALAR_LR=0.02, TIED_EMBED_LR=0.03). Systematic LR sweep showed default 0.04 was too high; optimal is ~0.02.", - "date": "2026-03-18T22:30:00Z", - "val_loss": 2.06492760, - "val_bpb": 1.22296644, - "bytes_total": 15854246, - "bytes_code": 50919 -} diff --git a/records/track_10min_16mb/2026-03-18_LowerLR/train.log b/records/track_10min_16mb/2026-03-18_LowerLR/train.log deleted file mode 100644 index f13b904b81..0000000000 --- a/records/track_10min_16mb/2026-03-18_LowerLR/train.log +++ /dev/null @@ -1,1667 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - lr_warmup_steps = int(os.environ.get("LR_WARMUP_STEPS", 200)) - lr_schedule = os.environ.get("LR_SCHEDULE", "linear") # "linear" (original) or "cosine" - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qat_start_frac = float(os.environ.get("QAT_START_FRAC", 0.0)) # fraction of training to start QAT (0=disabled, 0.8=last 20%) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_type = os.environ.get("MLP_TYPE", "relu2") # "relu2" (original) or "swiglu" - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -_QAT_ENABLED = False # Set by training loop during QAT phase - -def _fake_quantize_int8_ste(w: Tensor) -> Tensor: - """Simulate per-row int8 quantization with straight-through estimator.""" - with torch.no_grad(): - w32 = w.float() - if w32.ndim == 2: - amax = w32.abs().amax(dim=1, keepdim=True).clamp_min(1e-7) - scale = amax / 127.0 - q = torch.clamp(torch.round(w32 / scale), -127, 127) - w_q = (q * scale).to(w.dtype) - else: - amax = w32.abs().amax().clamp_min(1e-7) - scale = amax / 127.0 - q = torch.clamp(torch.round(w32 / scale), -127, 127) - w_q = (q * scale).to(w.dtype) - # STE: forward uses quantized weights, backward uses original gradients - return w + (w_q - w).detach() - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if _QAT_ENABLED and self.weight.numel() > INT8_KEEP_FLOAT_MAX_NUMEL: - w = _fake_quantize_int8_ste(w) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SwiGLUMLP(nn.Module): - # SwiGLU MLP (LLaMA-style). Parameter-matched to relu^2 MLP: - # relu2 has 2*dim*hidden params; swiglu has 3*dim*hidden, so hidden = 2/3 * mlp_mult * dim - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(2 * mlp_mult * dim / 3) - # Round to nearest multiple of 8 for efficiency - hidden = ((hidden + 7) // 8) * 8 - self.gate = CastedLinear(dim, hidden, bias=False) - self.up = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - return self.proj(F.silu(self.gate(x)) * self.up(x)) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - mlp_type: str, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = SwiGLUMLP(dim, mlp_mult) if mlp_type == "swiglu" else MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - mlp_type: str, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - mlp_type, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - mlp_type=args.mlp_type, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - # LR warmup: linear ramp from 0 to 1 over lr_warmup_steps - warmup_mul = min(step / max(args.lr_warmup_steps, 1), 1.0) if args.lr_warmup_steps > 0 else 1.0 - - if args.lr_schedule == "cosine": - # Cosine decay: estimate total steps from wallclock, decay over full run - if max_wallclock_ms is not None and step > 0: - step_ms = elapsed_ms / step - estimated_total = int(max_wallclock_ms / step_ms) - else: - estimated_total = args.iterations - progress = min(step / max(estimated_total, 1), 1.0) - decay_mul = 0.5 * (1.0 + math.cos(math.pi * progress)) - return warmup_mul * max(decay_mul, 0.0) - - # Original linear warmdown schedule - if args.warmdown_iters <= 0: - return warmup_mul - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - decay_mul = max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - return warmup_mul * decay_mul - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - decay_mul = remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - return warmup_mul * decay_mul - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - # Enable QAT for the last fraction of training - global _QAT_ENABLED - if args.qat_start_frac > 0 and max_wallclock_ms is not None: - _QAT_ENABLED = elapsed_ms >= args.qat_start_frac * max_wallclock_ms - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.11.9 (main, Nov 10 2025, 02:08:09) [GCC 11.4.0] -Running PyTorch 2.8.0+cu128 -Wed Mar 18 21:58:25 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H200 Off | 00000002:00:01.0 Off | 0 | -| N/A 34C P0 123W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H200 Off | 00000002:00:02.0 Off | 0 | -| N/A 38C P0 121W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H200 Off | 00000002:00:03.0 Off | 0 | -| N/A 39C P0 124W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H200 Off | 00000002:00:04.0 Off | 0 | -| N/A 34C P0 122W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H200 Off | 00000003:00:01.0 Off | 0 | -| N/A 34C P0 122W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H200 Off | 00000003:00:02.0 Off | 0 | -| N/A 41C P0 123W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H200 Off | 00000003:00:03.0 Off | 0 | -| N/A 38C P0 119W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H200 Off | 00000003:00:04.0 Off | 0 | -| N/A 33C P0 120W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 3930892 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 1 N/A N/A 3930893 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 2 N/A N/A 3930894 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 3 N/A N/A 3930895 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 4 N/A N/A 3930896 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 5 N/A N/A 3930897 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 6 N/A N/A 3930898 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 7 N/A N/A 3930899 C ...nv/versions/3.11.9/bin/python 1506MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:180 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9370 train_time:27ms step_avg:26.75ms -step:2/20000 train_loss:6.9395 train_time:67ms step_avg:33.75ms -step:3/20000 train_loss:6.9196 train_time:108ms step_avg:36.00ms -step:4/20000 train_loss:6.8807 train_time:158ms step_avg:39.57ms -step:5/20000 train_loss:6.8231 train_time:200ms step_avg:40.03ms -step:6/20000 train_loss:6.7488 train_time:242ms step_avg:40.37ms -step:7/20000 train_loss:6.6228 train_time:285ms step_avg:40.68ms -step:8/20000 train_loss:6.4791 train_time:317ms step_avg:39.57ms -step:9/20000 train_loss:6.3144 train_time:358ms step_avg:39.73ms -step:10/20000 train_loss:6.1475 train_time:399ms step_avg:39.88ms -step:50/20000 train_loss:5.1576 train_time:2023ms step_avg:40.47ms -step:100/20000 train_loss:4.0009 train_time:4063ms step_avg:40.63ms -step:150/20000 train_loss:3.3581 train_time:6099ms step_avg:40.66ms -step:200/20000 train_loss:2.9814 train_time:8294ms step_avg:41.47ms -step:200/20000 val_loss:2.9791 val_bpb:1.7644 train_time:8306ms step_avg:41.53ms -step:250/20000 train_loss:2.8186 train_time:10326ms step_avg:41.30ms -step:300/20000 train_loss:2.5205 train_time:12366ms step_avg:41.22ms -step:350/20000 train_loss:2.7077 train_time:14401ms step_avg:41.15ms -step:400/20000 train_loss:2.3745 train_time:16576ms step_avg:41.44ms -step:400/20000 val_loss:2.5855 val_bpb:1.5313 train_time:16595ms step_avg:41.49ms -step:450/20000 train_loss:2.5282 train_time:18609ms step_avg:41.35ms -step:500/20000 train_loss:2.5144 train_time:20646ms step_avg:41.29ms -step:550/20000 train_loss:2.4087 train_time:22687ms step_avg:41.25ms -step:600/20000 train_loss:2.5538 train_time:24921ms step_avg:41.53ms -step:600/20000 val_loss:2.4592 val_bpb:1.4565 train_time:24937ms step_avg:41.56ms -step:650/20000 train_loss:2.3962 train_time:26951ms step_avg:41.46ms -step:700/20000 train_loss:2.4482 train_time:28984ms step_avg:41.41ms -step:750/20000 train_loss:2.2825 train_time:31030ms step_avg:41.37ms -step:800/20000 train_loss:2.3037 train_time:33204ms step_avg:41.50ms -step:800/20000 val_loss:2.3866 val_bpb:1.4135 train_time:33225ms step_avg:41.53ms -step:850/20000 train_loss:2.7243 train_time:35241ms step_avg:41.46ms -step:900/20000 train_loss:2.3437 train_time:37284ms step_avg:41.43ms -step:950/20000 train_loss:2.4026 train_time:39326ms step_avg:41.40ms -step:1000/20000 train_loss:2.3816 train_time:41503ms step_avg:41.50ms -step:1000/20000 val_loss:2.3386 val_bpb:1.3851 train_time:41516ms step_avg:41.52ms -step:1050/20000 train_loss:2.4882 train_time:43527ms step_avg:41.45ms -step:1100/20000 train_loss:2.2661 train_time:45563ms step_avg:41.42ms -step:1150/20000 train_loss:2.2548 train_time:47747ms step_avg:41.52ms -step:1200/20000 train_loss:2.3863 train_time:49794ms step_avg:41.50ms -step:1200/20000 val_loss:2.3043 val_bpb:1.3648 train_time:49797ms step_avg:41.50ms -step:1250/20000 train_loss:2.2116 train_time:51819ms step_avg:41.46ms -step:1300/20000 train_loss:2.3616 train_time:53857ms step_avg:41.43ms -step:1350/20000 train_loss:2.2752 train_time:56018ms step_avg:41.49ms -step:1400/20000 train_loss:2.4304 train_time:58072ms step_avg:41.48ms -step:1400/20000 val_loss:2.2823 val_bpb:1.3517 train_time:58082ms step_avg:41.49ms -step:1450/20000 train_loss:2.2339 train_time:60097ms step_avg:41.45ms -step:1500/20000 train_loss:2.2223 train_time:62137ms step_avg:41.42ms -step:1550/20000 train_loss:2.1548 train_time:64312ms step_avg:41.49ms -step:1600/20000 train_loss:2.0963 train_time:66353ms step_avg:41.47ms -step:1600/20000 val_loss:2.2670 val_bpb:1.3426 train_time:66361ms step_avg:41.48ms -step:1650/20000 train_loss:2.2300 train_time:68378ms step_avg:41.44ms -step:1700/20000 train_loss:2.1717 train_time:70418ms step_avg:41.42ms -step:1750/20000 train_loss:2.2492 train_time:72603ms step_avg:41.49ms -step:1800/20000 train_loss:2.1983 train_time:74635ms step_avg:41.46ms -step:1800/20000 val_loss:2.2502 val_bpb:1.3327 train_time:74657ms step_avg:41.48ms -step:1850/20000 train_loss:2.3093 train_time:76681ms step_avg:41.45ms -step:1900/20000 train_loss:2.1858 train_time:78723ms step_avg:41.43ms -step:1950/20000 train_loss:2.2117 train_time:80921ms step_avg:41.50ms -step:2000/20000 train_loss:2.2537 train_time:82947ms step_avg:41.47ms -step:2000/20000 val_loss:2.2354 val_bpb:1.3239 train_time:82970ms step_avg:41.48ms -step:2050/20000 train_loss:2.2511 train_time:84982ms step_avg:41.45ms -step:2100/20000 train_loss:2.2674 train_time:87184ms step_avg:41.52ms -step:2150/20000 train_loss:2.1881 train_time:89227ms step_avg:41.50ms -step:2200/20000 train_loss:2.0740 train_time:91265ms step_avg:41.48ms -step:2200/20000 val_loss:2.2266 val_bpb:1.3187 train_time:91285ms step_avg:41.49ms -step:2250/20000 train_loss:2.1619 train_time:93305ms step_avg:41.47ms -step:2300/20000 train_loss:2.3779 train_time:95499ms step_avg:41.52ms -step:2350/20000 train_loss:2.2010 train_time:97535ms step_avg:41.50ms -step:2400/20000 train_loss:2.2015 train_time:99573ms step_avg:41.49ms -step:2400/20000 val_loss:2.2165 val_bpb:1.3127 train_time:99594ms step_avg:41.50ms -step:2450/20000 train_loss:2.2036 train_time:101613ms step_avg:41.47ms -step:2500/20000 train_loss:2.1252 train_time:103791ms step_avg:41.52ms -step:2550/20000 train_loss:2.1366 train_time:105846ms step_avg:41.51ms -step:2600/20000 train_loss:2.4066 train_time:107889ms step_avg:41.50ms -step:2600/20000 val_loss:2.2161 val_bpb:1.3125 train_time:107909ms step_avg:41.50ms -step:2650/20000 train_loss:2.2435 train_time:109929ms step_avg:41.48ms -step:2700/20000 train_loss:2.1554 train_time:112113ms step_avg:41.52ms -step:2750/20000 train_loss:2.3645 train_time:114133ms step_avg:41.50ms -step:2800/20000 train_loss:2.2346 train_time:116174ms step_avg:41.49ms -step:2800/20000 val_loss:2.2012 val_bpb:1.3037 train_time:116194ms step_avg:41.50ms -step:2850/20000 train_loss:2.1841 train_time:118209ms step_avg:41.48ms -step:2900/20000 train_loss:2.1789 train_time:120386ms step_avg:41.51ms -step:2950/20000 train_loss:2.2378 train_time:122425ms step_avg:41.50ms -step:3000/20000 train_loss:2.2245 train_time:124465ms step_avg:41.49ms -step:3000/20000 val_loss:2.1939 val_bpb:1.2994 train_time:124485ms step_avg:41.49ms -step:3050/20000 train_loss:2.1679 train_time:126506ms step_avg:41.48ms -step:3100/20000 train_loss:2.2084 train_time:128674ms step_avg:41.51ms -step:3150/20000 train_loss:2.1561 train_time:130715ms step_avg:41.50ms -step:3200/20000 train_loss:2.1857 train_time:132756ms step_avg:41.49ms -step:3200/20000 val_loss:2.1893 val_bpb:1.2966 train_time:132776ms step_avg:41.49ms -step:3250/20000 train_loss:2.0908 train_time:134931ms step_avg:41.52ms -step:3300/20000 train_loss:2.2346 train_time:136978ms step_avg:41.51ms -step:3350/20000 train_loss:2.0918 train_time:139011ms step_avg:41.50ms -step:3400/20000 train_loss:2.1562 train_time:141051ms step_avg:41.49ms -step:3400/20000 val_loss:2.1852 val_bpb:1.2942 train_time:141073ms step_avg:41.49ms -step:3450/20000 train_loss:2.1047 train_time:143351ms step_avg:41.55ms -step:3500/20000 train_loss:2.2490 train_time:145398ms step_avg:41.54ms -step:3550/20000 train_loss:2.3891 train_time:147432ms step_avg:41.53ms -step:3600/20000 train_loss:2.1173 train_time:149471ms step_avg:41.52ms -step:3600/20000 val_loss:2.1779 val_bpb:1.2899 train_time:149492ms step_avg:41.53ms -step:3650/20000 train_loss:2.2191 train_time:151652ms step_avg:41.55ms -step:3700/20000 train_loss:2.1520 train_time:153697ms step_avg:41.54ms -step:3750/20000 train_loss:2.1464 train_time:155743ms step_avg:41.53ms -step:3800/20000 train_loss:2.2206 train_time:157776ms step_avg:41.52ms -step:3800/20000 val_loss:2.1740 val_bpb:1.2876 train_time:157798ms step_avg:41.53ms -step:3850/20000 train_loss:2.1749 train_time:159987ms step_avg:41.56ms -step:3900/20000 train_loss:1.9899 train_time:162037ms step_avg:41.55ms -step:3950/20000 train_loss:2.1256 train_time:164070ms step_avg:41.54ms -step:4000/20000 train_loss:2.1611 train_time:166109ms step_avg:41.53ms -step:4000/20000 val_loss:2.1691 val_bpb:1.2847 train_time:166123ms step_avg:41.53ms -step:4050/20000 train_loss:2.0968 train_time:168299ms step_avg:41.56ms -step:4100/20000 train_loss:2.1860 train_time:170355ms step_avg:41.55ms -step:4150/20000 train_loss:2.3258 train_time:172391ms step_avg:41.54ms -step:4200/20000 train_loss:2.1748 train_time:174581ms step_avg:41.57ms -step:4200/20000 val_loss:2.1651 val_bpb:1.2823 train_time:174599ms step_avg:41.57ms -step:4250/20000 train_loss:2.1285 train_time:176628ms step_avg:41.56ms -step:4300/20000 train_loss:2.0212 train_time:178664ms step_avg:41.55ms -step:4350/20000 train_loss:2.2053 train_time:180796ms step_avg:41.56ms -step:4400/20000 train_loss:2.1137 train_time:182978ms step_avg:41.59ms -step:4400/20000 val_loss:2.1651 val_bpb:1.2823 train_time:183003ms step_avg:41.59ms -step:4450/20000 train_loss:2.0669 train_time:185022ms step_avg:41.58ms -step:4500/20000 train_loss:2.2530 train_time:187065ms step_avg:41.57ms -step:4550/20000 train_loss:2.0511 train_time:189106ms step_avg:41.56ms -step:4600/20000 train_loss:1.9698 train_time:191295ms step_avg:41.59ms -step:4600/20000 val_loss:2.1604 val_bpb:1.2795 train_time:191314ms step_avg:41.59ms -step:4650/20000 train_loss:2.0762 train_time:193332ms step_avg:41.58ms -step:4700/20000 train_loss:2.2670 train_time:195381ms step_avg:41.57ms -step:4750/20000 train_loss:1.9802 train_time:197412ms step_avg:41.56ms -step:4800/20000 train_loss:2.2632 train_time:199621ms step_avg:41.59ms -step:4800/20000 val_loss:2.1566 val_bpb:1.2773 train_time:199631ms step_avg:41.59ms -step:4850/20000 train_loss:2.1477 train_time:201656ms step_avg:41.58ms -step:4900/20000 train_loss:2.1638 train_time:203691ms step_avg:41.57ms -step:4950/20000 train_loss:2.3388 train_time:205733ms step_avg:41.56ms -step:5000/20000 train_loss:2.0292 train_time:207937ms step_avg:41.59ms -step:5000/20000 val_loss:2.1522 val_bpb:1.2746 train_time:207949ms step_avg:41.59ms -step:5050/20000 train_loss:2.2033 train_time:209966ms step_avg:41.58ms -step:5100/20000 train_loss:2.0231 train_time:212016ms step_avg:41.57ms -step:5150/20000 train_loss:2.2743 train_time:214202ms step_avg:41.59ms -step:5200/20000 train_loss:2.1705 train_time:216243ms step_avg:41.59ms -step:5200/20000 val_loss:2.1518 val_bpb:1.2744 train_time:216266ms step_avg:41.59ms -step:5250/20000 train_loss:2.1210 train_time:218283ms step_avg:41.58ms -step:5300/20000 train_loss:2.2122 train_time:220327ms step_avg:41.57ms -step:5350/20000 train_loss:2.1381 train_time:222499ms step_avg:41.59ms -step:5400/20000 train_loss:2.1869 train_time:224549ms step_avg:41.58ms -step:5400/20000 val_loss:2.1478 val_bpb:1.2721 train_time:224570ms step_avg:41.59ms -step:5450/20000 train_loss:2.1970 train_time:226597ms step_avg:41.58ms -step:5500/20000 train_loss:2.1369 train_time:228630ms step_avg:41.57ms -step:5550/20000 train_loss:2.1072 train_time:230827ms step_avg:41.59ms -step:5600/20000 train_loss:2.1806 train_time:232867ms step_avg:41.58ms -step:5600/20000 val_loss:2.1472 val_bpb:1.2717 train_time:232888ms step_avg:41.59ms -step:5650/20000 train_loss:2.0527 train_time:234911ms step_avg:41.58ms -step:5700/20000 train_loss:2.1755 train_time:236963ms step_avg:41.57ms -step:5750/20000 train_loss:2.2142 train_time:239137ms step_avg:41.59ms -step:5800/20000 train_loss:2.1436 train_time:241182ms step_avg:41.58ms -step:5800/20000 val_loss:2.1447 val_bpb:1.2702 train_time:241200ms step_avg:41.59ms -step:5850/20000 train_loss:2.1794 train_time:243224ms step_avg:41.58ms -step:5900/20000 train_loss:2.0950 train_time:245271ms step_avg:41.57ms -step:5950/20000 train_loss:2.1320 train_time:247462ms step_avg:41.59ms -step:6000/20000 train_loss:2.2187 train_time:249501ms step_avg:41.58ms -step:6000/20000 val_loss:2.1419 val_bpb:1.2686 train_time:249521ms step_avg:41.59ms -step:6050/20000 train_loss:2.1237 train_time:251549ms step_avg:41.58ms -step:6100/20000 train_loss:2.1189 train_time:253580ms step_avg:41.57ms -step:6150/20000 train_loss:2.1001 train_time:255763ms step_avg:41.59ms -step:6200/20000 train_loss:2.0846 train_time:257796ms step_avg:41.58ms -step:6200/20000 val_loss:2.1402 val_bpb:1.2676 train_time:257819ms step_avg:41.58ms -step:6250/20000 train_loss:2.1530 train_time:259841ms step_avg:41.57ms -step:6300/20000 train_loss:2.0335 train_time:262015ms step_avg:41.59ms -step:6350/20000 train_loss:2.0226 train_time:264055ms step_avg:41.58ms -step:6400/20000 train_loss:2.1600 train_time:266098ms step_avg:41.58ms -step:6400/20000 val_loss:2.1375 val_bpb:1.2660 train_time:266113ms step_avg:41.58ms -step:6450/20000 train_loss:2.0737 train_time:268131ms step_avg:41.57ms -step:6500/20000 train_loss:2.0776 train_time:270315ms step_avg:41.59ms -step:6550/20000 train_loss:2.2121 train_time:272357ms step_avg:41.58ms -step:6600/20000 train_loss:2.1218 train_time:274402ms step_avg:41.58ms -step:6600/20000 val_loss:2.1340 val_bpb:1.2639 train_time:274418ms step_avg:41.58ms -step:6650/20000 train_loss:2.2922 train_time:276435ms step_avg:41.57ms -step:6700/20000 train_loss:2.1544 train_time:278630ms step_avg:41.59ms -step:6750/20000 train_loss:2.3298 train_time:280669ms step_avg:41.58ms -step:6800/20000 train_loss:2.1900 train_time:282710ms step_avg:41.58ms -step:6800/20000 val_loss:2.1321 val_bpb:1.2627 train_time:282732ms step_avg:41.58ms -step:6850/20000 train_loss:2.0224 train_time:284745ms step_avg:41.57ms -step:6900/20000 train_loss:2.0966 train_time:286927ms step_avg:41.58ms -step:6950/20000 train_loss:2.1716 train_time:288970ms step_avg:41.58ms -step:7000/20000 train_loss:2.2194 train_time:291023ms step_avg:41.57ms -step:7000/20000 val_loss:2.1303 val_bpb:1.2617 train_time:291032ms step_avg:41.58ms -step:7050/20000 train_loss:2.2478 train_time:293055ms step_avg:41.57ms -step:7100/20000 train_loss:2.0619 train_time:295238ms step_avg:41.58ms -step:7150/20000 train_loss:2.1470 train_time:297279ms step_avg:41.58ms -step:7200/20000 train_loss:2.1975 train_time:299325ms step_avg:41.57ms -step:7200/20000 val_loss:2.1302 val_bpb:1.2616 train_time:299345ms step_avg:41.58ms -step:7250/20000 train_loss:2.0975 train_time:301520ms step_avg:41.59ms -step:7300/20000 train_loss:2.0807 train_time:303560ms step_avg:41.58ms -step:7350/20000 train_loss:2.1775 train_time:305601ms step_avg:41.58ms -step:7400/20000 train_loss:2.1169 train_time:307642ms step_avg:41.57ms -step:7400/20000 val_loss:2.1272 val_bpb:1.2599 train_time:307663ms step_avg:41.58ms -step:7450/20000 train_loss:2.1137 train_time:309830ms step_avg:41.59ms -step:7500/20000 train_loss:2.1093 train_time:311864ms step_avg:41.58ms -step:7550/20000 train_loss:2.1671 train_time:313901ms step_avg:41.58ms -step:7600/20000 train_loss:1.9977 train_time:315946ms step_avg:41.57ms -step:7600/20000 val_loss:2.1262 val_bpb:1.2593 train_time:315967ms step_avg:41.57ms -step:7650/20000 train_loss:2.2816 train_time:318120ms step_avg:41.58ms -step:7700/20000 train_loss:2.0852 train_time:320179ms step_avg:41.58ms -step:7750/20000 train_loss:2.1093 train_time:322212ms step_avg:41.58ms -step:7800/20000 train_loss:2.1463 train_time:324257ms step_avg:41.57ms -step:7800/20000 val_loss:2.1232 val_bpb:1.2575 train_time:324277ms step_avg:41.57ms -step:7850/20000 train_loss:1.9975 train_time:326453ms step_avg:41.59ms -step:7900/20000 train_loss:2.1335 train_time:328489ms step_avg:41.58ms -step:7950/20000 train_loss:2.0959 train_time:330530ms step_avg:41.58ms -step:8000/20000 train_loss:2.1129 train_time:332566ms step_avg:41.57ms -step:8000/20000 val_loss:2.1209 val_bpb:1.2561 train_time:332587ms step_avg:41.57ms -step:8050/20000 train_loss:2.0770 train_time:334758ms step_avg:41.58ms -step:8100/20000 train_loss:2.1438 train_time:336797ms step_avg:41.58ms -step:8150/20000 train_loss:2.2537 train_time:338838ms step_avg:41.58ms -step:8200/20000 train_loss:2.1884 train_time:340881ms step_avg:41.57ms -step:8200/20000 val_loss:2.1198 val_bpb:1.2555 train_time:340896ms step_avg:41.57ms -step:8250/20000 train_loss:2.1417 train_time:343054ms step_avg:41.58ms -step:8300/20000 train_loss:2.1192 train_time:345096ms step_avg:41.58ms -step:8350/20000 train_loss:2.2272 train_time:347133ms step_avg:41.57ms -step:8400/20000 train_loss:2.1348 train_time:349312ms step_avg:41.58ms -step:8400/20000 val_loss:2.1193 val_bpb:1.2552 train_time:349332ms step_avg:41.59ms -step:8450/20000 train_loss:2.2258 train_time:351351ms step_avg:41.58ms -step:8500/20000 train_loss:2.1283 train_time:353392ms step_avg:41.58ms -step:8550/20000 train_loss:2.1952 train_time:355438ms step_avg:41.57ms -step:8600/20000 train_loss:2.1338 train_time:357628ms step_avg:41.58ms -step:8600/20000 val_loss:2.1167 val_bpb:1.2536 train_time:357649ms step_avg:41.59ms -step:8650/20000 train_loss:2.0988 train_time:359663ms step_avg:41.58ms -step:8700/20000 train_loss:2.0309 train_time:361710ms step_avg:41.58ms -step:8750/20000 train_loss:2.1958 train_time:363752ms step_avg:41.57ms -step:8800/20000 train_loss:2.1006 train_time:365952ms step_avg:41.59ms -step:8800/20000 val_loss:2.1161 val_bpb:1.2532 train_time:365973ms step_avg:41.59ms -step:8850/20000 train_loss:2.3085 train_time:367999ms step_avg:41.58ms -step:8900/20000 train_loss:2.1983 train_time:370035ms step_avg:41.58ms -step:8950/20000 train_loss:2.1584 train_time:372078ms step_avg:41.57ms -step:9000/20000 train_loss:2.0222 train_time:374260ms step_avg:41.58ms -step:9000/20000 val_loss:2.1159 val_bpb:1.2532 train_time:374284ms step_avg:41.59ms -step:9050/20000 train_loss:2.0589 train_time:376310ms step_avg:41.58ms -step:9100/20000 train_loss:2.3027 train_time:378344ms step_avg:41.58ms -step:9150/20000 train_loss:1.9951 train_time:380383ms step_avg:41.57ms -step:9200/20000 train_loss:2.0824 train_time:382571ms step_avg:41.58ms -step:9200/20000 val_loss:2.1141 val_bpb:1.2521 train_time:382592ms step_avg:41.59ms -step:9250/20000 train_loss:2.1997 train_time:384620ms step_avg:41.58ms -step:9300/20000 train_loss:2.1245 train_time:386656ms step_avg:41.58ms -step:9350/20000 train_loss:2.2276 train_time:388840ms step_avg:41.59ms -step:9400/20000 train_loss:2.1315 train_time:390881ms step_avg:41.58ms -step:9400/20000 val_loss:2.1122 val_bpb:1.2510 train_time:390901ms step_avg:41.59ms -step:9450/20000 train_loss:2.1611 train_time:392928ms step_avg:41.58ms -step:9500/20000 train_loss:2.2559 train_time:394969ms step_avg:41.58ms -step:9550/20000 train_loss:2.1978 train_time:397134ms step_avg:41.58ms -step:9600/20000 train_loss:2.1446 train_time:399182ms step_avg:41.58ms -step:9600/20000 val_loss:2.1113 val_bpb:1.2504 train_time:399195ms step_avg:41.58ms -step:9650/20000 train_loss:2.0912 train_time:401213ms step_avg:41.58ms -step:9700/20000 train_loss:2.1043 train_time:403258ms step_avg:41.57ms -step:9750/20000 train_loss:2.0616 train_time:405438ms step_avg:41.58ms -step:9800/20000 train_loss:2.0748 train_time:407485ms step_avg:41.58ms -step:9800/20000 val_loss:2.1112 val_bpb:1.2503 train_time:407507ms step_avg:41.58ms -step:9850/20000 train_loss:2.0331 train_time:409526ms step_avg:41.58ms -step:9900/20000 train_loss:2.1455 train_time:411576ms step_avg:41.57ms -step:9950/20000 train_loss:2.0212 train_time:413748ms step_avg:41.58ms -step:10000/20000 train_loss:2.1121 train_time:415797ms step_avg:41.58ms -step:10000/20000 val_loss:2.1106 val_bpb:1.2500 train_time:415810ms step_avg:41.58ms -step:10050/20000 train_loss:2.1000 train_time:417829ms step_avg:41.58ms -step:10100/20000 train_loss:2.0911 train_time:419866ms step_avg:41.57ms -step:10150/20000 train_loss:2.0633 train_time:422044ms step_avg:41.58ms -step:10200/20000 train_loss:2.0664 train_time:424080ms step_avg:41.58ms -step:10200/20000 val_loss:2.1073 val_bpb:1.2481 train_time:424100ms step_avg:41.58ms -step:10250/20000 train_loss:2.0637 train_time:426124ms step_avg:41.57ms -step:10300/20000 train_loss:2.1898 train_time:428338ms step_avg:41.59ms -step:10350/20000 train_loss:2.1276 train_time:430363ms step_avg:41.58ms -step:10400/20000 train_loss:2.0965 train_time:432406ms step_avg:41.58ms -step:10400/20000 val_loss:2.1071 val_bpb:1.2480 train_time:432426ms step_avg:41.58ms -step:10450/20000 train_loss:2.0837 train_time:434443ms step_avg:41.57ms -step:10500/20000 train_loss:1.9827 train_time:436639ms step_avg:41.58ms -step:10550/20000 train_loss:2.0099 train_time:438689ms step_avg:41.58ms -step:10600/20000 train_loss:1.9746 train_time:440730ms step_avg:41.58ms -step:10600/20000 val_loss:2.1071 val_bpb:1.2479 train_time:440744ms step_avg:41.58ms -step:10650/20000 train_loss:2.1979 train_time:442761ms step_avg:41.57ms -step:10700/20000 train_loss:2.0719 train_time:444932ms step_avg:41.58ms -step:10750/20000 train_loss:2.1335 train_time:446980ms step_avg:41.58ms -step:10800/20000 train_loss:2.1824 train_time:449018ms step_avg:41.58ms -step:10800/20000 val_loss:2.1050 val_bpb:1.2467 train_time:449043ms step_avg:41.58ms -step:10850/20000 train_loss:2.1403 train_time:451055ms step_avg:41.57ms -step:10900/20000 train_loss:2.1497 train_time:453278ms step_avg:41.59ms -step:10950/20000 train_loss:2.1130 train_time:455306ms step_avg:41.58ms -step:11000/20000 train_loss:2.1196 train_time:457343ms step_avg:41.58ms -step:11000/20000 val_loss:2.1038 val_bpb:1.2460 train_time:457364ms step_avg:41.58ms -step:11050/20000 train_loss:2.0790 train_time:459384ms step_avg:41.57ms -step:11100/20000 train_loss:2.0557 train_time:461577ms step_avg:41.58ms -step:11150/20000 train_loss:2.1606 train_time:463610ms step_avg:41.58ms -step:11200/20000 train_loss:2.0688 train_time:465649ms step_avg:41.58ms -step:11200/20000 val_loss:2.1029 val_bpb:1.2454 train_time:465671ms step_avg:41.58ms -step:11250/20000 train_loss:1.9445 train_time:467691ms step_avg:41.57ms -step:11300/20000 train_loss:1.9937 train_time:469892ms step_avg:41.58ms -step:11350/20000 train_loss:1.9798 train_time:471919ms step_avg:41.58ms -step:11400/20000 train_loss:2.0541 train_time:473961ms step_avg:41.58ms -step:11400/20000 val_loss:2.1033 val_bpb:1.2457 train_time:473981ms step_avg:41.58ms -step:11450/20000 train_loss:2.0449 train_time:476142ms step_avg:41.58ms -step:11500/20000 train_loss:2.1067 train_time:478185ms step_avg:41.58ms -step:11550/20000 train_loss:2.1091 train_time:480227ms step_avg:41.58ms -step:11600/20000 train_loss:2.0575 train_time:482265ms step_avg:41.57ms -step:11600/20000 val_loss:2.1010 val_bpb:1.2444 train_time:482285ms step_avg:41.58ms -step:11650/20000 train_loss:2.1794 train_time:484478ms step_avg:41.59ms -step:11700/20000 train_loss:2.2088 train_time:486515ms step_avg:41.58ms -step:11750/20000 train_loss:2.1199 train_time:488561ms step_avg:41.58ms -step:11800/20000 train_loss:2.0924 train_time:490596ms step_avg:41.58ms -step:11800/20000 val_loss:2.1000 val_bpb:1.2437 train_time:490617ms step_avg:41.58ms -step:11850/20000 train_loss:2.1360 train_time:492795ms step_avg:41.59ms -step:11900/20000 train_loss:2.0570 train_time:494848ms step_avg:41.58ms -step:11950/20000 train_loss:2.0831 train_time:496888ms step_avg:41.58ms -step:12000/20000 train_loss:2.0670 train_time:498923ms step_avg:41.58ms -step:12000/20000 val_loss:2.0985 val_bpb:1.2428 train_time:498946ms step_avg:41.58ms -step:12050/20000 train_loss:2.0843 train_time:501106ms step_avg:41.59ms -step:12100/20000 train_loss:2.1068 train_time:503144ms step_avg:41.58ms -step:12150/20000 train_loss:2.2622 train_time:505197ms step_avg:41.58ms -step:12200/20000 train_loss:2.2202 train_time:507228ms step_avg:41.58ms -step:12200/20000 val_loss:2.0984 val_bpb:1.2428 train_time:507250ms step_avg:41.58ms -step:12250/20000 train_loss:1.9104 train_time:509409ms step_avg:41.58ms -step:12300/20000 train_loss:2.1068 train_time:511453ms step_avg:41.58ms -step:12350/20000 train_loss:2.1745 train_time:513496ms step_avg:41.58ms -step:12400/20000 train_loss:1.8565 train_time:515722ms step_avg:41.59ms -step:12400/20000 val_loss:2.0982 val_bpb:1.2427 train_time:515743ms step_avg:41.59ms -step:12450/20000 train_loss:2.0249 train_time:517767ms step_avg:41.59ms -step:12500/20000 train_loss:2.3661 train_time:519800ms step_avg:41.58ms -step:12550/20000 train_loss:2.1391 train_time:521843ms step_avg:41.58ms -step:12600/20000 train_loss:2.0869 train_time:524014ms step_avg:41.59ms -step:12600/20000 val_loss:2.0987 val_bpb:1.2430 train_time:524036ms step_avg:41.59ms -step:12650/20000 train_loss:2.0534 train_time:526069ms step_avg:41.59ms -step:12700/20000 train_loss:2.0893 train_time:528110ms step_avg:41.58ms -step:12750/20000 train_loss:2.1019 train_time:530150ms step_avg:41.58ms -step:12800/20000 train_loss:2.1097 train_time:532313ms step_avg:41.59ms -step:12800/20000 val_loss:2.0972 val_bpb:1.2421 train_time:532334ms step_avg:41.59ms -step:12850/20000 train_loss:2.0167 train_time:534356ms step_avg:41.58ms -step:12900/20000 train_loss:2.1447 train_time:536388ms step_avg:41.58ms -step:12950/20000 train_loss:2.0178 train_time:538439ms step_avg:41.58ms -step:13000/20000 train_loss:2.1921 train_time:540633ms step_avg:41.59ms -step:13000/20000 val_loss:2.0971 val_bpb:1.2420 train_time:540655ms step_avg:41.59ms -step:13050/20000 train_loss:2.1280 train_time:542674ms step_avg:41.58ms -step:13100/20000 train_loss:2.0189 train_time:544714ms step_avg:41.58ms -step:13150/20000 train_loss:2.0806 train_time:546755ms step_avg:41.58ms -step:13200/20000 train_loss:2.2015 train_time:548947ms step_avg:41.59ms -step:13200/20000 val_loss:2.0942 val_bpb:1.2403 train_time:548959ms step_avg:41.59ms -step:13250/20000 train_loss:2.2855 train_time:550972ms step_avg:41.58ms -step:13300/20000 train_loss:2.0595 train_time:553013ms step_avg:41.58ms -step:13350/20000 train_loss:2.0810 train_time:555056ms step_avg:41.58ms -step:13400/20000 train_loss:2.0734 train_time:557234ms step_avg:41.58ms -step:13400/20000 val_loss:2.0899 val_bpb:1.2378 train_time:557254ms step_avg:41.59ms -step:13450/20000 train_loss:2.1027 train_time:559274ms step_avg:41.58ms -step:13500/20000 train_loss:2.0397 train_time:561311ms step_avg:41.58ms -step:13550/20000 train_loss:2.1138 train_time:563503ms step_avg:41.59ms -step:13600/20000 train_loss:1.9325 train_time:565549ms step_avg:41.58ms -step:13600/20000 val_loss:2.0847 val_bpb:1.2347 train_time:565564ms step_avg:41.59ms -step:13650/20000 train_loss:2.1098 train_time:567585ms step_avg:41.58ms -step:13700/20000 train_loss:2.1089 train_time:569632ms step_avg:41.58ms -step:13750/20000 train_loss:2.1835 train_time:571799ms step_avg:41.59ms -step:13800/20000 train_loss:2.0195 train_time:573851ms step_avg:41.58ms -step:13800/20000 val_loss:2.0755 val_bpb:1.2292 train_time:573871ms step_avg:41.58ms -step:13850/20000 train_loss:2.0397 train_time:575897ms step_avg:41.58ms -step:13900/20000 train_loss:2.1022 train_time:577930ms step_avg:41.58ms -step:13950/20000 train_loss:1.9046 train_time:580197ms step_avg:41.59ms -step:14000/20000 train_loss:2.0752 train_time:582232ms step_avg:41.59ms -step:14000/20000 val_loss:2.0691 val_bpb:1.2254 train_time:582252ms step_avg:41.59ms -step:14050/20000 train_loss:1.9633 train_time:584270ms step_avg:41.59ms -step:14100/20000 train_loss:2.0442 train_time:586464ms step_avg:41.59ms -step:14150/20000 train_loss:2.1136 train_time:588646ms step_avg:41.60ms -step:14200/20000 train_loss:2.1544 train_time:590693ms step_avg:41.60ms -step:14200/20000 val_loss:2.0628 val_bpb:1.2217 train_time:590705ms step_avg:41.60ms -step:14250/20000 train_loss:2.2636 train_time:592724ms step_avg:41.59ms -step:14300/20000 train_loss:2.1059 train_time:594763ms step_avg:41.59ms -step:14350/20000 train_loss:2.0329 train_time:596936ms step_avg:41.60ms -step:14400/20000 train_loss:2.0489 train_time:598976ms step_avg:41.60ms -step:14400/20000 val_loss:2.0573 val_bpb:1.2184 train_time:598996ms step_avg:41.60ms -step:14421/20000 val_loss:2.0571 val_bpb:1.2183 train_time:599847ms step_avg:41.60ms -stopping_early: wallclock_cap train_time:599847ms step:14421/20000 -peak memory allocated: 10246 MiB reserved: 10310 MiB -Serialized model: 67224983 bytes -Code size: 50919 bytes -Total submission size: 67275902 bytes -Serialized model int8+zlib: 15803327 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15854246 bytes -final_int8_zlib_roundtrip val_loss:2.0649 val_bpb:1.2230 eval_time:1284ms -final_int8_zlib_roundtrip_exact val_loss:2.06492760 val_bpb:1.22296644 diff --git a/records/track_10min_16mb/2026-03-18_LowerLR/train_gpt.py b/records/track_10min_16mb/2026-03-18_LowerLR/train_gpt.py deleted file mode 100644 index 0deb0565f5..0000000000 --- a/records/track_10min_16mb/2026-03-18_LowerLR/train_gpt.py +++ /dev/null @@ -1,1126 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/README.md b/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/README.md deleted file mode 100644 index ce9815d0d6..0000000000 --- a/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/README.md +++ /dev/null @@ -1,75 +0,0 @@ -This record captures the `10L Mixed Precision` submission. - -## Summary - -Two key improvements over the baseline: - -1. **10 transformer layers** instead of 9 — adds depth for better language modeling -2. **Lower learning rates** — MATRIX_LR=0.02, SCALAR_LR=0.02, TIED_EMBED_LR=0.03 (vs default 0.04/0.04/0.05) -3. **Mixed int8/int6 compression** — middle layers (3,4,5,6) use int6 precision (round int8 to nearest 4) for better zlib compression, while first/last layers keep full int8 - -The 10-layer model at dim=512 has 18.9M params which compresses to 17.6MB with standard int8+zlib — 1.6MB over the 16MB cap. By reducing precision on the 4 middle layers to int6 (64 quantization levels instead of 256), the compressed size drops to 15.9MB with only 0.0018 bpb quality loss from quantization. - -## Configuration - -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=10 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied output/input embeddings: `TIE_EMBEDDINGS=1` -- Learning rates: `MATRIX_LR=0.02 SCALAR_LR=0.02 TIED_EMBED_LR=0.03` -- Mixed precision: `INT4_LAYERS=3,4,5,6 INT4_STEP=4` (int6 for middle layers) -- Batching: `TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=1024` - -## Command - -```bash -RUN_ID=exp45_10L_int6_mid \ -DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -MAX_WALLCLOCK_SECONDS=600 \ -VAL_LOSS_EVERY=200 \ -TRAIN_LOG_EVERY=50 \ -NUM_LAYERS=10 \ -MATRIX_LR=0.02 \ -SCALAR_LR=0.02 \ -TIED_EMBED_LR=0.03 \ -INT4_LAYERS=3,4,5,6 \ -INT4_STEP=4 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -## Key Metrics - -- Pre-quant eval: `val_loss:2.0480`, `val_bpb:1.2129` -- Post-quant (int8/int6 mixed + zlib): `val_loss:2.0510`, `val_bpb:1.2147` -- Exact: `final_int8_zlib_roundtrip_exact val_bpb:1.21474500` -- Quantization gap: 0.0018 bpb (vs baseline's 0.0093) -- Train time: `599732ms` (`step_avg:45.78ms`) -- Steps: 13,100/20,000 (wallclock limited) -- Peak memory: 11,389 MiB allocated -- Artifact: 15,928,974 bytes (code: 48,917 + model: 15,880,057) - -## Compression Analysis - -| Layer Group | Precision | Reason | -|---|---|---| -| Layers 0-2 (early) | int8 (256 levels) | Critical for input processing | -| Layers 3-6 (middle) | int6 (64 levels) | Less sensitive, saves ~1.6MB | -| Layers 7-9 (late) | int8 (256 levels) | Critical for output quality | - -## LR Sweep Results - -Systematic sweep showed default LR (0.04) was too high: -| MATRIX_LR | val_bpb (9L baseline) | -|---|---| -| 0.04 (default) | 1.2286 | -| 0.02 (optimal) | 1.2230 | - -## Note on Hardware - -Run performed on 8xH200 (step_avg: 45.78ms). H100 baseline was 43.54ms/step for 9 layers; 10 layers would be ~47-48ms on H100, yielding ~12,500-12,700 steps. Results should be comparable. - -## Included Files - -- `train_gpt.py` (code snapshot) -- `train.log` (training log) -- `submission.json` (metadata) diff --git a/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/submission.json b/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/submission.json deleted file mode 100644 index 0f3c94741f..0000000000 --- a/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "Nan Liu", - "github_id": "nanlliu", - "name": "10L Mixed Precision", - "blurb": "10-layer 512-dim model with lower LR (MATRIX_LR=0.02) and mixed int8/int6 compression: full int8 for first/last 3 layers, int6 (step=4 rounding) for middle layers 3-6. Fits 16MB via better compression while gaining an extra transformer layer over baseline.", - "date": "2026-03-19T03:30:00Z", - "val_loss": 2.05104604, - "val_bpb": 1.21474500, - "bytes_total": 15928974, - "bytes_code": 48917 -} diff --git a/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/train.log b/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/train.log deleted file mode 100644 index c40e941144..0000000000 --- a/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/train.log +++ /dev/null @@ -1,1591 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - prune_ratio = float(os.environ.get("PRUNE_RATIO", 0.0)) # fraction of int8 range to prune (e.g. 0.1 = zero out |val| <= 12) - int4_layers = os.environ.get("INT4_LAYERS", "") # comma-separated layer indices for reduced precision (e.g. "3,4,5,6") - int4_step = int(os.environ.get("INT4_STEP", 16)) # rounding step: 2=int7, 4=int6, 8=int5, 16=int4 - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - # Optional post-quantization pruning: zero out small int8 values for better compression - if args.prune_ratio > 0: - threshold = int(127 * args.prune_ratio) - for name in list(quant_obj.get("quantized", {}).keys()): - t = quant_obj["quantized"][name] - t[t.abs() <= threshold] = 0 - # Optional mixed-precision: round middle layers to int4 (16 levels) for better compression - if args.int4_layers: - int4_set = set(int(x) for x in args.int4_layers.split(",") if x.strip()) - for name in list(quant_obj.get("quantized", {}).keys()): - layer_num = -1 - if "blocks." in name: - try: - layer_num = int(name.split("blocks.")[1].split(".")[0]) - except (ValueError, IndexError): - pass - if layer_num in int4_set: - t = quant_obj["quantized"][name] - step = args.int4_step - quant_obj["quantized"][name] = ((t.float() / step).round() * step).clamp(-127, 127).to(torch.int8) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.11.9 (main, Nov 10 2025, 02:08:09) [GCC 11.4.0] -Running PyTorch 2.8.0+cu128 -Thu Mar 19 02:57:29 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H200 Off | 00000002:00:01.0 Off | 0 | -| N/A 32C P0 122W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H200 Off | 00000002:00:02.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H200 Off | 00000002:00:03.0 Off | 0 | -| N/A 35C P0 121W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H200 Off | 00000002:00:04.0 Off | 0 | -| N/A 32C P0 121W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H200 Off | 00000003:00:01.0 Off | 0 | -| N/A 32C P0 120W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H200 Off | 00000003:00:02.0 Off | 0 | -| N/A 37C P0 121W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H200 Off | 00000003:00:03.0 Off | 0 | -| N/A 35C P0 117W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H200 Off | 00000003:00:04.0 Off | 0 | -| N/A 31C P0 119W / 700W | 1516MiB / 143771MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 301282 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 1 N/A N/A 301283 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 2 N/A N/A 301284 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 3 N/A N/A 301285 C ...nv/versions/3.11.9/bin/python 1530MiB | -| 4 N/A N/A 301286 C ...nv/versions/3.11.9/bin/python 1534MiB | -| 5 N/A N/A 301287 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 6 N/A N/A 301288 C ...nv/versions/3.11.9/bin/python 1506MiB | -| 7 N/A N/A 301289 C ...nv/versions/3.11.9/bin/python 1506MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:180 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:18897488 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9363 val_bpb:4.1080 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9355 train_time:31ms step_avg:31.28ms -step:2/20000 train_loss:12.1231 train_time:73ms step_avg:36.68ms -step:3/20000 train_loss:7.2165 train_time:125ms step_avg:41.81ms -step:4/20000 train_loss:6.4176 train_time:165ms step_avg:41.19ms -step:5/20000 train_loss:6.8785 train_time:213ms step_avg:42.51ms -step:6/20000 train_loss:7.6248 train_time:256ms step_avg:42.66ms -step:7/20000 train_loss:6.8402 train_time:303ms step_avg:43.24ms -step:8/20000 train_loss:6.4526 train_time:348ms step_avg:43.54ms -step:9/20000 train_loss:6.2574 train_time:394ms step_avg:43.79ms -step:10/20000 train_loss:6.1774 train_time:441ms step_avg:44.06ms -step:50/20000 train_loss:4.0508 train_time:2226ms step_avg:44.51ms -step:100/20000 train_loss:3.2461 train_time:4465ms step_avg:44.65ms -step:150/20000 train_loss:2.9349 train_time:6704ms step_avg:44.69ms -step:200/20000 train_loss:2.7576 train_time:9077ms step_avg:45.38ms -step:200/20000 val_loss:2.7303 val_bpb:1.6170 train_time:9103ms step_avg:45.52ms -step:250/20000 train_loss:2.6695 train_time:11324ms step_avg:45.30ms -step:300/20000 train_loss:2.4316 train_time:13571ms step_avg:45.24ms -step:350/20000 train_loss:2.6215 train_time:15818ms step_avg:45.19ms -step:400/20000 train_loss:2.3094 train_time:18198ms step_avg:45.49ms -step:400/20000 val_loss:2.5186 val_bpb:1.4917 train_time:18221ms step_avg:45.55ms -step:450/20000 train_loss:2.4622 train_time:20441ms step_avg:45.42ms -step:500/20000 train_loss:2.4606 train_time:22696ms step_avg:45.39ms -step:550/20000 train_loss:2.3666 train_time:24931ms step_avg:45.33ms -step:600/20000 train_loss:2.5197 train_time:27345ms step_avg:45.57ms -step:600/20000 val_loss:2.4169 val_bpb:1.4315 train_time:27361ms step_avg:45.60ms -step:650/20000 train_loss:2.3557 train_time:29586ms step_avg:45.52ms -step:700/20000 train_loss:2.4120 train_time:31830ms step_avg:45.47ms -step:750/20000 train_loss:2.2413 train_time:34073ms step_avg:45.43ms -step:800/20000 train_loss:2.2639 train_time:36456ms step_avg:45.57ms -step:800/20000 val_loss:2.3534 val_bpb:1.3938 train_time:36482ms step_avg:45.60ms -step:850/20000 train_loss:2.6951 train_time:38715ms step_avg:45.55ms -step:900/20000 train_loss:2.3127 train_time:40946ms step_avg:45.50ms -step:950/20000 train_loss:2.3720 train_time:43193ms step_avg:45.47ms -step:1000/20000 train_loss:2.3507 train_time:45593ms step_avg:45.59ms -step:1000/20000 val_loss:2.3092 val_bpb:1.3677 train_time:45617ms step_avg:45.62ms -step:1050/20000 train_loss:2.4618 train_time:47845ms step_avg:45.57ms -step:1100/20000 train_loss:2.2360 train_time:50109ms step_avg:45.55ms -step:1150/20000 train_loss:2.2286 train_time:52507ms step_avg:45.66ms -step:1200/20000 train_loss:2.3629 train_time:54759ms step_avg:45.63ms -step:1200/20000 val_loss:2.2795 val_bpb:1.3501 train_time:54781ms step_avg:45.65ms -step:1250/20000 train_loss:2.1861 train_time:57017ms step_avg:45.61ms -step:1300/20000 train_loss:2.3350 train_time:59267ms step_avg:45.59ms -step:1350/20000 train_loss:2.2513 train_time:61658ms step_avg:45.67ms -step:1400/20000 train_loss:2.4046 train_time:63916ms step_avg:45.65ms -step:1400/20000 val_loss:2.2574 val_bpb:1.3370 train_time:63932ms step_avg:45.67ms -step:1450/20000 train_loss:2.2156 train_time:66161ms step_avg:45.63ms -step:1500/20000 train_loss:2.1999 train_time:68424ms step_avg:45.62ms -step:1550/20000 train_loss:2.1331 train_time:70797ms step_avg:45.68ms -step:1600/20000 train_loss:2.0744 train_time:73040ms step_avg:45.65ms -step:1600/20000 val_loss:2.2430 val_bpb:1.3284 train_time:73065ms step_avg:45.67ms -step:1650/20000 train_loss:2.2082 train_time:75294ms step_avg:45.63ms -step:1700/20000 train_loss:2.1512 train_time:77538ms step_avg:45.61ms -step:1750/20000 train_loss:2.2288 train_time:79928ms step_avg:45.67ms -step:1800/20000 train_loss:2.1758 train_time:82184ms step_avg:45.66ms -step:1800/20000 val_loss:2.2278 val_bpb:1.3194 train_time:82201ms step_avg:45.67ms -step:1850/20000 train_loss:2.2862 train_time:84429ms step_avg:45.64ms -step:1900/20000 train_loss:2.1669 train_time:86684ms step_avg:45.62ms -step:1950/20000 train_loss:2.1909 train_time:89069ms step_avg:45.68ms -step:2000/20000 train_loss:2.2319 train_time:91324ms step_avg:45.66ms -step:2000/20000 val_loss:2.2129 val_bpb:1.3106 train_time:91351ms step_avg:45.68ms -step:2050/20000 train_loss:2.2279 train_time:93580ms step_avg:45.65ms -step:2100/20000 train_loss:2.2434 train_time:95987ms step_avg:45.71ms -step:2150/20000 train_loss:2.1649 train_time:98238ms step_avg:45.69ms -step:2200/20000 train_loss:2.0522 train_time:100494ms step_avg:45.68ms -step:2200/20000 val_loss:2.2044 val_bpb:1.3055 train_time:100513ms step_avg:45.69ms -step:2250/20000 train_loss:2.1400 train_time:102739ms step_avg:45.66ms -step:2300/20000 train_loss:2.3533 train_time:105126ms step_avg:45.71ms -step:2350/20000 train_loss:2.1788 train_time:107379ms step_avg:45.69ms -step:2400/20000 train_loss:2.1812 train_time:109631ms step_avg:45.68ms -step:2400/20000 val_loss:2.1944 val_bpb:1.2997 train_time:109653ms step_avg:45.69ms -step:2450/20000 train_loss:2.1842 train_time:111889ms step_avg:45.67ms -step:2500/20000 train_loss:2.0987 train_time:114285ms step_avg:45.71ms -step:2550/20000 train_loss:2.1142 train_time:116542ms step_avg:45.70ms -step:2600/20000 train_loss:2.3890 train_time:118794ms step_avg:45.69ms -step:2600/20000 val_loss:2.1926 val_bpb:1.2986 train_time:118817ms step_avg:45.70ms -step:2650/20000 train_loss:2.2227 train_time:121046ms step_avg:45.68ms -step:2700/20000 train_loss:2.1354 train_time:123439ms step_avg:45.72ms -step:2750/20000 train_loss:2.3391 train_time:125689ms step_avg:45.71ms -step:2800/20000 train_loss:2.2132 train_time:127942ms step_avg:45.69ms -step:2800/20000 val_loss:2.1801 val_bpb:1.2912 train_time:127965ms step_avg:45.70ms -step:2850/20000 train_loss:2.1652 train_time:130191ms step_avg:45.68ms -step:2900/20000 train_loss:2.1553 train_time:132603ms step_avg:45.73ms -step:2950/20000 train_loss:2.2221 train_time:134847ms step_avg:45.71ms -step:3000/20000 train_loss:2.2030 train_time:137103ms step_avg:45.70ms -step:3000/20000 val_loss:2.1722 val_bpb:1.2865 train_time:137127ms step_avg:45.71ms -step:3050/20000 train_loss:2.1472 train_time:139356ms step_avg:45.69ms -step:3100/20000 train_loss:2.1888 train_time:141742ms step_avg:45.72ms -step:3150/20000 train_loss:2.1400 train_time:143993ms step_avg:45.71ms -step:3200/20000 train_loss:2.1645 train_time:146242ms step_avg:45.70ms -step:3200/20000 val_loss:2.1671 val_bpb:1.2835 train_time:146268ms step_avg:45.71ms -step:3250/20000 train_loss:2.0671 train_time:148637ms step_avg:45.73ms -step:3300/20000 train_loss:2.2137 train_time:150885ms step_avg:45.72ms -step:3350/20000 train_loss:2.0731 train_time:153128ms step_avg:45.71ms -step:3400/20000 train_loss:2.1360 train_time:155380ms step_avg:45.70ms -step:3400/20000 val_loss:2.1642 val_bpb:1.2818 train_time:155406ms step_avg:45.71ms -step:3450/20000 train_loss:2.0890 train_time:157762ms step_avg:45.73ms -step:3500/20000 train_loss:2.2285 train_time:160022ms step_avg:45.72ms -step:3550/20000 train_loss:2.3672 train_time:162275ms step_avg:45.71ms -step:3600/20000 train_loss:2.0920 train_time:164524ms step_avg:45.70ms -step:3600/20000 val_loss:2.1558 val_bpb:1.2768 train_time:164550ms step_avg:45.71ms -step:3650/20000 train_loss:2.2009 train_time:166952ms step_avg:45.74ms -step:3700/20000 train_loss:2.1262 train_time:169204ms step_avg:45.73ms -step:3750/20000 train_loss:2.1242 train_time:171454ms step_avg:45.72ms -step:3800/20000 train_loss:2.1979 train_time:173690ms step_avg:45.71ms -step:3800/20000 val_loss:2.1522 val_bpb:1.2747 train_time:173716ms step_avg:45.71ms -step:3850/20000 train_loss:2.1522 train_time:176098ms step_avg:45.74ms -step:3900/20000 train_loss:1.9648 train_time:178363ms step_avg:45.73ms -step:3950/20000 train_loss:2.1063 train_time:180605ms step_avg:45.72ms -step:4000/20000 train_loss:2.1404 train_time:182855ms step_avg:45.71ms -step:4000/20000 val_loss:2.1483 val_bpb:1.2723 train_time:182880ms step_avg:45.72ms -step:4050/20000 train_loss:2.0791 train_time:185248ms step_avg:45.74ms -step:4100/20000 train_loss:2.1671 train_time:187504ms step_avg:45.73ms -step:4150/20000 train_loss:2.3022 train_time:189754ms step_avg:45.72ms -step:4200/20000 train_loss:2.1544 train_time:192160ms step_avg:45.75ms -step:4200/20000 val_loss:2.1438 val_bpb:1.2697 train_time:192179ms step_avg:45.76ms -step:4250/20000 train_loss:2.1045 train_time:194411ms step_avg:45.74ms -step:4300/20000 train_loss:2.0035 train_time:196665ms step_avg:45.74ms -step:4350/20000 train_loss:2.1856 train_time:198915ms step_avg:45.73ms -step:4400/20000 train_loss:2.0900 train_time:201320ms step_avg:45.75ms -step:4400/20000 val_loss:2.1441 val_bpb:1.2698 train_time:201344ms step_avg:45.76ms -step:4450/20000 train_loss:2.0453 train_time:203578ms step_avg:45.75ms -step:4500/20000 train_loss:2.2388 train_time:205817ms step_avg:45.74ms -step:4550/20000 train_loss:2.0374 train_time:208069ms step_avg:45.73ms -step:4600/20000 train_loss:1.9508 train_time:210467ms step_avg:45.75ms -step:4600/20000 val_loss:2.1397 val_bpb:1.2672 train_time:210490ms step_avg:45.76ms -step:4650/20000 train_loss:2.0559 train_time:212730ms step_avg:45.75ms -step:4700/20000 train_loss:2.2390 train_time:214973ms step_avg:45.74ms -step:4750/20000 train_loss:1.9568 train_time:217228ms step_avg:45.73ms -step:4800/20000 train_loss:2.2412 train_time:219613ms step_avg:45.75ms -step:4800/20000 val_loss:2.1358 val_bpb:1.2649 train_time:219633ms step_avg:45.76ms -step:4850/20000 train_loss:2.1316 train_time:221864ms step_avg:45.75ms -step:4900/20000 train_loss:2.1431 train_time:224117ms step_avg:45.74ms -step:4950/20000 train_loss:2.3209 train_time:226367ms step_avg:45.73ms -step:5000/20000 train_loss:2.0024 train_time:228775ms step_avg:45.75ms -step:5000/20000 val_loss:2.1310 val_bpb:1.2621 train_time:228799ms step_avg:45.76ms -step:5050/20000 train_loss:2.1805 train_time:231025ms step_avg:45.75ms -step:5100/20000 train_loss:2.0028 train_time:233279ms step_avg:45.74ms -step:5150/20000 train_loss:2.2555 train_time:235682ms step_avg:45.76ms -step:5200/20000 train_loss:2.1495 train_time:237949ms step_avg:45.76ms -step:5200/20000 val_loss:2.1311 val_bpb:1.2621 train_time:237972ms step_avg:45.76ms -step:5250/20000 train_loss:2.1005 train_time:240194ms step_avg:45.75ms -step:5300/20000 train_loss:2.1919 train_time:242450ms step_avg:45.75ms -step:5350/20000 train_loss:2.1187 train_time:244835ms step_avg:45.76ms -step:5400/20000 train_loss:2.1622 train_time:247089ms step_avg:45.76ms -step:5400/20000 val_loss:2.1266 val_bpb:1.2595 train_time:247105ms step_avg:45.76ms -step:5450/20000 train_loss:2.1771 train_time:249340ms step_avg:45.75ms -step:5500/20000 train_loss:2.1191 train_time:251580ms step_avg:45.74ms -step:5550/20000 train_loss:2.0819 train_time:253964ms step_avg:45.76ms -step:5600/20000 train_loss:2.1589 train_time:256223ms step_avg:45.75ms -step:5600/20000 val_loss:2.1262 val_bpb:1.2593 train_time:256248ms step_avg:45.76ms -step:5650/20000 train_loss:2.0342 train_time:258475ms step_avg:45.75ms -step:5700/20000 train_loss:2.1558 train_time:260721ms step_avg:45.74ms -step:5750/20000 train_loss:2.1967 train_time:263132ms step_avg:45.76ms -step:5800/20000 train_loss:2.1182 train_time:265385ms step_avg:45.76ms -step:5800/20000 val_loss:2.1233 val_bpb:1.2575 train_time:265407ms step_avg:45.76ms -step:5850/20000 train_loss:2.1585 train_time:267633ms step_avg:45.75ms -step:5900/20000 train_loss:2.0729 train_time:269888ms step_avg:45.74ms -step:5950/20000 train_loss:2.1111 train_time:272265ms step_avg:45.76ms -step:6000/20000 train_loss:2.2017 train_time:274533ms step_avg:45.76ms -step:6000/20000 val_loss:2.1214 val_bpb:1.2564 train_time:274549ms step_avg:45.76ms -step:6050/20000 train_loss:2.1040 train_time:276778ms step_avg:45.75ms -step:6100/20000 train_loss:2.0972 train_time:279041ms step_avg:45.74ms -step:6150/20000 train_loss:2.0780 train_time:281419ms step_avg:45.76ms -step:6200/20000 train_loss:2.0625 train_time:283675ms step_avg:45.75ms -step:6200/20000 val_loss:2.1192 val_bpb:1.2551 train_time:283696ms step_avg:45.76ms -step:6250/20000 train_loss:2.1312 train_time:285929ms step_avg:45.75ms -step:6300/20000 train_loss:2.0109 train_time:288345ms step_avg:45.77ms -step:6350/20000 train_loss:1.9996 train_time:290597ms step_avg:45.76ms -step:6400/20000 train_loss:2.1393 train_time:292855ms step_avg:45.76ms -step:6400/20000 val_loss:2.1165 val_bpb:1.2535 train_time:292870ms step_avg:45.76ms -step:6450/20000 train_loss:2.0561 train_time:295104ms step_avg:45.75ms -step:6500/20000 train_loss:2.0563 train_time:297478ms step_avg:45.77ms -step:6550/20000 train_loss:2.1868 train_time:299733ms step_avg:45.76ms -step:6600/20000 train_loss:2.1016 train_time:301990ms step_avg:45.76ms -step:6600/20000 val_loss:2.1126 val_bpb:1.2512 train_time:302015ms step_avg:45.76ms -step:6650/20000 train_loss:2.2676 train_time:304249ms step_avg:45.75ms -step:6700/20000 train_loss:2.1344 train_time:306638ms step_avg:45.77ms -step:6750/20000 train_loss:2.3069 train_time:308884ms step_avg:45.76ms -step:6800/20000 train_loss:2.1721 train_time:311133ms step_avg:45.75ms -step:6800/20000 val_loss:2.1120 val_bpb:1.2508 train_time:311158ms step_avg:45.76ms -step:6850/20000 train_loss:2.0027 train_time:313387ms step_avg:45.75ms -step:6900/20000 train_loss:2.0708 train_time:315782ms step_avg:45.77ms -step:6950/20000 train_loss:2.1526 train_time:318017ms step_avg:45.76ms -step:7000/20000 train_loss:2.2026 train_time:320274ms step_avg:45.75ms -step:7000/20000 val_loss:2.1100 val_bpb:1.2497 train_time:320294ms step_avg:45.76ms -step:7050/20000 train_loss:2.2315 train_time:322525ms step_avg:45.75ms -step:7100/20000 train_loss:2.0491 train_time:325043ms step_avg:45.78ms -step:7150/20000 train_loss:2.1278 train_time:327297ms step_avg:45.78ms -step:7200/20000 train_loss:2.1779 train_time:329554ms step_avg:45.77ms -step:7200/20000 val_loss:2.1091 val_bpb:1.2492 train_time:329578ms step_avg:45.77ms -step:7250/20000 train_loss:2.0836 train_time:331941ms step_avg:45.78ms -step:7300/20000 train_loss:2.0657 train_time:334194ms step_avg:45.78ms -step:7350/20000 train_loss:2.1627 train_time:336441ms step_avg:45.77ms -step:7400/20000 train_loss:2.0957 train_time:338691ms step_avg:45.77ms -step:7400/20000 val_loss:2.1063 val_bpb:1.2475 train_time:338718ms step_avg:45.77ms -step:7450/20000 train_loss:2.0930 train_time:341085ms step_avg:45.78ms -step:7500/20000 train_loss:2.0893 train_time:343339ms step_avg:45.78ms -step:7550/20000 train_loss:2.1512 train_time:345594ms step_avg:45.77ms -step:7600/20000 train_loss:1.9759 train_time:347849ms step_avg:45.77ms -step:7600/20000 val_loss:2.1051 val_bpb:1.2467 train_time:347873ms step_avg:45.77ms -step:7650/20000 train_loss:2.2603 train_time:350250ms step_avg:45.78ms -step:7700/20000 train_loss:2.0707 train_time:352506ms step_avg:45.78ms -step:7750/20000 train_loss:2.0895 train_time:354756ms step_avg:45.78ms -step:7800/20000 train_loss:2.1256 train_time:357009ms step_avg:45.77ms -step:7800/20000 val_loss:2.1026 val_bpb:1.2453 train_time:357034ms step_avg:45.77ms -step:7850/20000 train_loss:1.9783 train_time:359400ms step_avg:45.78ms -step:7900/20000 train_loss:2.1132 train_time:361649ms step_avg:45.78ms -step:7950/20000 train_loss:2.0734 train_time:363898ms step_avg:45.77ms -step:8000/20000 train_loss:2.0955 train_time:366149ms step_avg:45.77ms -step:8000/20000 val_loss:2.1000 val_bpb:1.2437 train_time:366174ms step_avg:45.77ms -step:8050/20000 train_loss:2.0608 train_time:368554ms step_avg:45.78ms -step:8100/20000 train_loss:2.1254 train_time:370808ms step_avg:45.78ms -step:8150/20000 train_loss:2.2333 train_time:373056ms step_avg:45.77ms -step:8200/20000 train_loss:2.1677 train_time:375301ms step_avg:45.77ms -step:8200/20000 val_loss:2.0991 val_bpb:1.2432 train_time:375326ms step_avg:45.77ms -step:8250/20000 train_loss:2.1263 train_time:377716ms step_avg:45.78ms -step:8300/20000 train_loss:2.0939 train_time:379966ms step_avg:45.78ms -step:8350/20000 train_loss:2.2056 train_time:382213ms step_avg:45.77ms -step:8400/20000 train_loss:2.1145 train_time:384602ms step_avg:45.79ms -step:8400/20000 val_loss:2.0986 val_bpb:1.2429 train_time:384627ms step_avg:45.79ms -step:8450/20000 train_loss:2.2060 train_time:386859ms step_avg:45.78ms -step:8500/20000 train_loss:2.1064 train_time:389125ms step_avg:45.78ms -step:8550/20000 train_loss:2.1724 train_time:391359ms step_avg:45.77ms -step:8600/20000 train_loss:2.1130 train_time:393767ms step_avg:45.79ms -step:8600/20000 val_loss:2.0957 val_bpb:1.2412 train_time:393779ms step_avg:45.79ms -step:8650/20000 train_loss:2.0786 train_time:396002ms step_avg:45.78ms -step:8700/20000 train_loss:2.0084 train_time:398259ms step_avg:45.78ms -step:8750/20000 train_loss:2.1722 train_time:400508ms step_avg:45.77ms -step:8800/20000 train_loss:2.0755 train_time:402921ms step_avg:45.79ms -step:8800/20000 val_loss:2.0949 val_bpb:1.2407 train_time:402947ms step_avg:45.79ms -step:8850/20000 train_loss:2.2856 train_time:405174ms step_avg:45.78ms -step:8900/20000 train_loss:2.1794 train_time:407424ms step_avg:45.78ms -step:8950/20000 train_loss:2.1362 train_time:409675ms step_avg:45.77ms -step:9000/20000 train_loss:2.0033 train_time:412068ms step_avg:45.79ms -step:9000/20000 val_loss:2.0952 val_bpb:1.2409 train_time:412090ms step_avg:45.79ms -step:9050/20000 train_loss:2.0378 train_time:414315ms step_avg:45.78ms -step:9100/20000 train_loss:2.2843 train_time:416568ms step_avg:45.78ms -step:9150/20000 train_loss:1.9771 train_time:418810ms step_avg:45.77ms -step:9200/20000 train_loss:2.0639 train_time:421205ms step_avg:45.78ms -step:9200/20000 val_loss:2.0933 val_bpb:1.2398 train_time:421229ms step_avg:45.79ms -step:9250/20000 train_loss:2.1723 train_time:423454ms step_avg:45.78ms -step:9300/20000 train_loss:2.1050 train_time:425706ms step_avg:45.77ms -step:9350/20000 train_loss:2.2069 train_time:428101ms step_avg:45.79ms -step:9400/20000 train_loss:2.1070 train_time:430349ms step_avg:45.78ms -step:9400/20000 val_loss:2.0910 val_bpb:1.2384 train_time:430374ms step_avg:45.78ms -step:9450/20000 train_loss:2.1418 train_time:432606ms step_avg:45.78ms -step:9500/20000 train_loss:2.2405 train_time:434854ms step_avg:45.77ms -step:9550/20000 train_loss:2.1759 train_time:437256ms step_avg:45.79ms -step:9600/20000 train_loss:2.1233 train_time:439511ms step_avg:45.78ms -step:9600/20000 val_loss:2.0904 val_bpb:1.2381 train_time:439538ms step_avg:45.79ms -step:9650/20000 train_loss:2.0687 train_time:441764ms step_avg:45.78ms -step:9700/20000 train_loss:2.0842 train_time:444016ms step_avg:45.77ms -step:9750/20000 train_loss:2.0422 train_time:446408ms step_avg:45.79ms -step:9800/20000 train_loss:2.0480 train_time:448658ms step_avg:45.78ms -step:9800/20000 val_loss:2.0920 val_bpb:1.2390 train_time:448685ms step_avg:45.78ms -step:9850/20000 train_loss:2.0115 train_time:450915ms step_avg:45.78ms -step:9900/20000 train_loss:2.1271 train_time:453166ms step_avg:45.77ms -step:9950/20000 train_loss:2.0026 train_time:455566ms step_avg:45.79ms -step:10000/20000 train_loss:2.0928 train_time:457811ms step_avg:45.78ms -step:10000/20000 val_loss:2.0901 val_bpb:1.2379 train_time:457836ms step_avg:45.78ms -step:10050/20000 train_loss:2.0821 train_time:460064ms step_avg:45.78ms -step:10100/20000 train_loss:2.0749 train_time:462314ms step_avg:45.77ms -step:10150/20000 train_loss:2.0414 train_time:464692ms step_avg:45.78ms -step:10200/20000 train_loss:2.0427 train_time:466956ms step_avg:45.78ms -step:10200/20000 val_loss:2.0864 val_bpb:1.2357 train_time:466975ms step_avg:45.78ms -step:10250/20000 train_loss:2.0468 train_time:469202ms step_avg:45.78ms -step:10300/20000 train_loss:2.1732 train_time:471628ms step_avg:45.79ms -step:10350/20000 train_loss:2.1053 train_time:473869ms step_avg:45.78ms -step:10400/20000 train_loss:2.0771 train_time:476113ms step_avg:45.78ms -step:10400/20000 val_loss:2.0866 val_bpb:1.2358 train_time:476129ms step_avg:45.78ms -step:10450/20000 train_loss:2.0673 train_time:478356ms step_avg:45.78ms -step:10500/20000 train_loss:1.9587 train_time:480740ms step_avg:45.78ms -step:10550/20000 train_loss:1.9904 train_time:482992ms step_avg:45.78ms -step:10600/20000 train_loss:1.9568 train_time:485256ms step_avg:45.78ms -step:10600/20000 val_loss:2.0864 val_bpb:1.2357 train_time:485269ms step_avg:45.78ms -step:10650/20000 train_loss:2.1711 train_time:487496ms step_avg:45.77ms -step:10700/20000 train_loss:2.0553 train_time:489875ms step_avg:45.78ms -step:10750/20000 train_loss:2.1130 train_time:492156ms step_avg:45.78ms -step:10800/20000 train_loss:2.1674 train_time:494388ms step_avg:45.78ms -step:10800/20000 val_loss:2.0843 val_bpb:1.2345 train_time:494408ms step_avg:45.78ms -step:10850/20000 train_loss:2.1129 train_time:496643ms step_avg:45.77ms -step:10900/20000 train_loss:2.1275 train_time:499051ms step_avg:45.78ms -step:10950/20000 train_loss:2.0914 train_time:501302ms step_avg:45.78ms -step:11000/20000 train_loss:2.0934 train_time:503554ms step_avg:45.78ms -step:11000/20000 val_loss:2.0832 val_bpb:1.2338 train_time:503578ms step_avg:45.78ms -step:11050/20000 train_loss:2.0554 train_time:505806ms step_avg:45.77ms -step:11100/20000 train_loss:2.0394 train_time:508195ms step_avg:45.78ms -step:11150/20000 train_loss:2.1367 train_time:510449ms step_avg:45.78ms -step:11200/20000 train_loss:2.0480 train_time:512705ms step_avg:45.78ms -step:11200/20000 val_loss:2.0825 val_bpb:1.2333 train_time:512727ms step_avg:45.78ms -step:11250/20000 train_loss:1.9284 train_time:514958ms step_avg:45.77ms -step:11300/20000 train_loss:1.9753 train_time:517364ms step_avg:45.78ms -step:11350/20000 train_loss:1.9595 train_time:519613ms step_avg:45.78ms -step:11400/20000 train_loss:2.0328 train_time:521860ms step_avg:45.78ms -step:11400/20000 val_loss:2.0824 val_bpb:1.2333 train_time:521885ms step_avg:45.78ms -step:11450/20000 train_loss:2.0225 train_time:524252ms step_avg:45.79ms -step:11500/20000 train_loss:2.0859 train_time:526498ms step_avg:45.78ms -step:11550/20000 train_loss:2.0876 train_time:528748ms step_avg:45.78ms -step:11600/20000 train_loss:2.0380 train_time:531003ms step_avg:45.78ms -step:11600/20000 val_loss:2.0805 val_bpb:1.2322 train_time:531027ms step_avg:45.78ms -step:11650/20000 train_loss:2.1590 train_time:533404ms step_avg:45.79ms -step:11700/20000 train_loss:2.1872 train_time:535658ms step_avg:45.78ms -step:11750/20000 train_loss:2.0971 train_time:537908ms step_avg:45.78ms -step:11800/20000 train_loss:2.0735 train_time:540163ms step_avg:45.78ms -step:11800/20000 val_loss:2.0792 val_bpb:1.2314 train_time:540177ms step_avg:45.78ms -step:11850/20000 train_loss:2.1114 train_time:542551ms step_avg:45.78ms -step:11900/20000 train_loss:2.0381 train_time:544803ms step_avg:45.78ms -step:11950/20000 train_loss:2.0598 train_time:547043ms step_avg:45.78ms -step:12000/20000 train_loss:2.0470 train_time:549297ms step_avg:45.77ms -step:12000/20000 val_loss:2.0757 val_bpb:1.2293 train_time:549321ms step_avg:45.78ms -step:12050/20000 train_loss:2.0646 train_time:551692ms step_avg:45.78ms -step:12100/20000 train_loss:2.0815 train_time:553944ms step_avg:45.78ms -step:12150/20000 train_loss:2.2381 train_time:556195ms step_avg:45.78ms -step:12200/20000 train_loss:2.1889 train_time:558450ms step_avg:45.77ms -step:12200/20000 val_loss:2.0693 val_bpb:1.2256 train_time:558473ms step_avg:45.78ms -step:12250/20000 train_loss:1.8816 train_time:560835ms step_avg:45.78ms -step:12300/20000 train_loss:2.0743 train_time:563087ms step_avg:45.78ms -step:12350/20000 train_loss:2.1383 train_time:565338ms step_avg:45.78ms -step:12400/20000 train_loss:1.8282 train_time:567731ms step_avg:45.78ms -step:12400/20000 val_loss:2.0625 val_bpb:1.2215 train_time:567756ms step_avg:45.79ms -step:12450/20000 train_loss:1.9936 train_time:569994ms step_avg:45.78ms -step:12500/20000 train_loss:2.3264 train_time:572244ms step_avg:45.78ms -step:12550/20000 train_loss:2.1008 train_time:574498ms step_avg:45.78ms -step:12600/20000 train_loss:2.0522 train_time:576894ms step_avg:45.79ms -step:12600/20000 val_loss:2.0554 val_bpb:1.2174 train_time:576920ms step_avg:45.79ms -step:12650/20000 train_loss:2.0168 train_time:579149ms step_avg:45.78ms -step:12700/20000 train_loss:2.0485 train_time:581399ms step_avg:45.78ms -step:12750/20000 train_loss:2.0600 train_time:583649ms step_avg:45.78ms -step:12800/20000 train_loss:2.0687 train_time:586059ms step_avg:45.79ms -step:12800/20000 val_loss:2.0480 val_bpb:1.2129 train_time:586089ms step_avg:45.79ms -step:12850/20000 train_loss:1.9659 train_time:588321ms step_avg:45.78ms -step:12900/20000 train_loss:2.0918 train_time:590577ms step_avg:45.78ms -step:12950/20000 train_loss:1.9700 train_time:592831ms step_avg:45.78ms -step:13000/20000 train_loss:2.1380 train_time:595226ms step_avg:45.79ms -step:13000/20000 val_loss:2.0416 val_bpb:1.2092 train_time:595253ms step_avg:45.79ms -step:13050/20000 train_loss:2.0707 train_time:597483ms step_avg:45.78ms -step:13100/20000 train_loss:1.9696 train_time:599748ms step_avg:45.78ms -step:13101/20000 val_loss:2.0396 val_bpb:1.2080 train_time:599812ms step_avg:45.78ms -stopping_early: wallclock_cap train_time:599812ms step:13101/20000 -peak memory allocated: 11389 MiB reserved: 11704 MiB -Serialized model: 74578915 bytes -Code size: 49058 bytes -Total submission size: 74627973 bytes -Serialized model int8+zlib: 15879916 bytes (payload:19030336 raw_torch:19080377 payload_ratio:3.92x) -Total submission size int8+zlib: 15928974 bytes -final_int8_zlib_roundtrip val_loss:2.0510 val_bpb:1.2147 eval_time:1432ms -final_int8_zlib_roundtrip_exact val_loss:2.05104604 val_bpb:1.21474500 diff --git a/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/train_gpt.py b/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/train_gpt.py deleted file mode 100644 index f06bd07f64..0000000000 --- a/records/track_10min_16mb/2026-03-19_10L_MixedPrecision/train_gpt.py +++ /dev/null @@ -1,1152 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - prune_ratio = float(os.environ.get("PRUNE_RATIO", 0.0)) # fraction of int8 range to prune (e.g. 0.1 = zero out |val| <= 12) - int4_layers = os.environ.get("INT4_LAYERS", "") # comma-separated layer indices for reduced precision (e.g. "3,4,5,6") - int4_step = int(os.environ.get("INT4_STEP", 16)) # rounding step: 2=int7, 4=int6, 8=int5, 16=int4 - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - # Optional post-quantization pruning: zero out small int8 values for better compression - if args.prune_ratio > 0: - threshold = int(127 * args.prune_ratio) - for name in list(quant_obj.get("quantized", {}).keys()): - t = quant_obj["quantized"][name] - t[t.abs() <= threshold] = 0 - # Optional mixed-precision: round middle layers to int4 (16 levels) for better compression - if args.int4_layers: - int4_set = set(int(x) for x in args.int4_layers.split(",") if x.strip()) - for name in list(quant_obj.get("quantized", {}).keys()): - layer_num = -1 - if "blocks." in name: - try: - layer_num = int(name.split("blocks.")[1].split(".")[0]) - except (ValueError, IndexError): - pass - if layer_num in int4_set: - t = quant_obj["quantized"][name] - step = args.int4_step - quant_obj["quantized"][name] = ((t.float() / step).round() * step).clamp(-127, 127).to(torch.int8) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/README.md b/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/README.md deleted file mode 100644 index ae4a43ec35..0000000000 --- a/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# 11L MLP3x + WD=0.04 + Int6 QAT + zstd-22 + Sliding Window Eval - -## Summary - -11-layer transformer with 3x MLP expansion, int6 quantization-aware training, decoupled weight decay (0.04), zstd-22 compression, and sliding window evaluation. This achieves **val_bpb = 1.1502** (mean across 3 seeds). - -### Key Changes from Baseline - -1. **11 transformer layers** (vs 9 baseline) — more effective depth, funded by aggressive int6 compression -2. **Wider MLP (MLP_MULT=3)** — 3x expansion (hidden=1536), more capacity per layer -3. **Decoupled weight decay (0.04)** — on both Muon and AdamW, keeps weights small and quantization-friendly -4. **QAT int6** — STE fake-quantize simulates int6 noise during training -5. **Int6 quantization on all block weights** (layers 0-10) -6. **FP16 tied embedding export** — preserves embedding/output head quality -7. **zstd-22 compression** — saves ~1.5MB vs zlib, critical for fitting 11L MLP3x under 16MB -8. **Sliding window evaluation (stride=64)** — ~0.034 BPB free improvement -9. **Higher Muon momentum (0.99)** with warmup from 0.92 over 1500 steps -10. **Lower learning rates**: MATRIX_LR=0.025, SCALAR_LR=0.025, TIED_EMBED_LR=0.035 - -### Architecture - -- 11 transformer blocks, 512 model dim, 8 attention heads, 4 KV heads -- GQA attention with RoPE, ReLU² MLP (**3x** expansion) -- Tied embeddings with 1024 BPE vocabulary -- U-Net skip connections (5 encoder + 6 decoder layers) -- 26.5M parameters, ~15.4MB compressed artifact (zstd-22) - -## Multi-Seed Results (3 seeds, p << 0.001) - -| Seed | slide_loss (nats) | slide_bpb | rt_bpb | Artifact | -|---|---|---|---|---| -| 1337 | 1.94265607 | 1.15055135 | 1.18484075 | 15,360,260 | -| 42 | 1.94207795 | 1.15020896 | 1.18456681 | 15,556,813 | -| 123 | 1.94121940 | 1.14970047 | 1.18421993 | 15,365,293 | -| **Mean** | **1.94198447** | **1.15015359** | **1.18454250** | **15,427,455** | -| **Std** | **0.00072288** | | | | - -- **Mean improvement: 0.1307 nats** over baseline -- **t-statistic: 313.20** (df=2, p << 0.001) -- All 3 artifacts under 16MB -- Sliding window eval takes ~88s on 8xH100 (under 10-min eval budget) - -## Hardware - -All runs on 8×H100 SXM (RunPod). ~10,070 training steps at ~59.6ms/step in 600s. - -## How to Run - -Requires `zstandard` package (`pip install zstandard`). - -```bash -RUN_ID=submission \ -SEED=1337 \ -NUM_LAYERS=11 \ -MLP_MULT=3 \ -MATRIX_LR=0.025 \ -SCALAR_LR=0.025 \ -TIED_EMBED_LR=0.035 \ -FP16_EMBED_EXPORT=1 \ -INT6_LAYER_START=0 \ -INT6_LAYER_END=10 \ -QAT_ENABLED=1 \ -QAT_INT6=1 \ -MUON_WEIGHT_DECAY=0.04 \ -ADAM_WEIGHT_DECAY=0.04 \ -MUON_MOMENTUM=0.99 \ -MUON_MOMENTUM_WARMUP_START=0.92 \ -MUON_MOMENTUM_WARMUP_STEPS=1500 \ -WARMDOWN_ITERS=3000 \ -USE_ZSTD=1 \ -EVAL_STRIDE=64 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` diff --git a/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/submission.json b/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/submission.json deleted file mode 100644 index 40bfeccd2d..0000000000 --- a/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/submission.json +++ /dev/null @@ -1,45 +0,0 @@ -{ - "author": "aruniyer", - "github_id": "aruniyer", - "run_name": "11L_MLP3x_WD04_Int6QAT_Zstd22_SlidingWindow", - "val_bpb": 1.15015359, - "val_bpb_roundtrip": 1.18454250, - "val_loss_sliding": 1.94198447, - "val_loss_roundtrip": 2.00005039, - "val_bpb_std": 0.00043, - "val_loss_std": 0.00072, - "n_seeds": 3, - "seeds": [1337, 42, 123], - "t_statistic": 313.20, - "p_value": "<0.001", - "artifact_bytes_mean": 15427455, - "model_params": 26501720, - "num_layers": 11, - "model_dim": 512, - "num_heads": 8, - "num_kv_heads": 4, - "mlp_mult": 3, - "vocab_size": 1024, - "train_seq_len": 1024, - "eval_stride": 64, - "tie_embeddings": true, - "matrix_lr": 0.025, - "scalar_lr": 0.025, - "tied_embed_lr": 0.035, - "muon_weight_decay": 0.04, - "adam_weight_decay": 0.04, - "muon_momentum": 0.99, - "fp16_embed_export": true, - "int6_layer_start": 0, - "int6_layer_end": 10, - "qat_enabled": true, - "qat_int6": true, - "use_zstd": true, - "zstd_level": 22, - "train_steps_mean": 10070, - "train_time_seconds": 600, - "eval_time_seconds": 88, - "step_avg_ms": 59.6, - "hardware": "8xH100 SXM (RunPod)", - "date": "2026-03-20" -} diff --git a/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/train.log b/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/train.log deleted file mode 100644 index 5e7f463576..0000000000 --- a/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/train.log +++ /dev/null @@ -1,1431 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - fp16_embed_export = bool(int(os.environ.get("FP16_EMBED_EXPORT", "0"))) - int6_layer_start = int(os.environ.get("INT6_LAYER_START", "-1")) - int6_layer_end = int(os.environ.get("INT6_LAYER_END", "-1")) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - qat_int6 = bool(int(os.environ.get("QAT_INT6", "0"))) - eval_stride = int(os.environ.get("EVAL_STRIDE", "0")) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", "0")) - eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "32")) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def quantize_float_tensor_int6(t: Tensor) -> tuple[Tensor, Tensor]: - """Int6-equivalent: quantize to int8 but round to multiples of 4 (64 levels).""" - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q_raw = torch.round(clipped / scale[:, None]) - q = torch.clamp((torch.round(q_raw / 4) * 4), -128, 124).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale_val = clip_abs / 127.0 if clip_abs > 0 else 1.0 - scale = torch.tensor(scale_val, dtype=torch.float32) - q_raw = torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale) - q = torch.clamp((torch.round(q_raw / 4) * 4), -128, 124).to(torch.int8).contiguous() - return q, scale - - -def quantize_state_dict_int8(state_dict: dict[str, Tensor], fp16_embed: bool = False, int6_layer_start: int = -1, int6_layer_end: int = -1): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Keep tied embedding in fp16 for better roundtrip quality - if fp16_embed and "tok_emb.weight" in name: - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - # Use int6 for middle layers to save space - use_int6 = False - if int6_layer_start >= 0 and int6_layer_end >= 0: - for layer_idx in range(int6_layer_start, int6_layer_end + 1): - if f"blocks.{layer_idx}." in name and t.ndim == 2: - use_int6 = True - break - q, s = quantize_float_tensor_int6(t) if use_int6 else quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - _qat: bool = False - _qat_int6: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight - if self._qat and self.training and w.ndim == 2: - w_f = w.float() - amax = w_f.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) - scale = amax / 127.0 - q_raw = (w_f / scale).round() - if self._qat_int6: - q = (q_raw / 4).round() * 4 - q = q.clamp(-128, 124) - else: - q = q_raw.clamp(-127, 127) - w_q = q * scale - w = w + (w_q - w_f).detach() # STE - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - - - -def forward_logits(model: nn.Module, input_ids: Tensor) -> Tensor: - """Forward pass returning logits (for sliding window eval). Uses uncompiled model.""" - x = model.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(model.num_encoder_layers): - x = model.blocks[i](x, x0) - skips.append(x) - for i in range(model.num_decoder_layers): - if skips: - x = x + model.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = model.blocks[model.num_encoder_layers + i](x, x0) - x = model.final_norm(x) - if model.tie_embeddings: - logits_proj = F.linear(x, model.tok_emb.weight) - else: - if model.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = model.lm_head(x) - return model.logit_softcap * torch.tanh(logits_proj / model.logit_softcap) - - -def eval_val_sliding( - args: "Hyperparameters", - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - """Sliding window evaluation for better BPB.""" - seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - stride = args.eval_stride if args.eval_stride > 0 else seq_len - batch_size = args.eval_batch_size - total_len = val_tokens.numel() - - # Generate window start positions - positions = list(range(0, total_len - seq_len, stride)) - # Distribute across ranks (interleaved for balance) - rank_positions = positions[rank::world_size] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for batch_start in range(0, len(rank_positions), batch_size): - batch_pos = rank_positions[batch_start : batch_start + batch_size] - bs = len(batch_pos) - - x = torch.stack( - [val_tokens[p : p + seq_len] for p in batch_pos] - ).to(device=device, dtype=torch.int64) - y = torch.stack( - [val_tokens[p + 1 : p + seq_len + 1] for p in batch_pos] - ).to(device=device, dtype=torch.int64) - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = forward_logits(base_model, x) # [bs, seq_len, vocab] - - # Score only the last stride tokens of each window - score_start = seq_len - stride - logits_s = logits[:, score_start:, :].float() - targets_s = y[:, score_start:] - - per_tok = F.cross_entropy( - logits_s.reshape(-1, logits_s.size(-1)), - targets_s.reshape(-1), - reduction="none", - ) - loss_sum += per_tok.to(torch.float64).sum() - token_count += float(targets_s.numel()) - - prev_ids = x[:, score_start:].reshape(-1) - tgt_ids = targets_s.reshape(-1) - tb = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) - byte_count += tb.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - return val_loss, float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - if args.qat_enabled: - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module._qat = True - module._qat_int6 = args.qat_int6 - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8( - base_model.state_dict(), - fp16_embed=args.fp16_embed_export if args.tie_embeddings else False, - int6_layer_start=args.int6_layer_start, - int6_layer_end=args.int6_layer_end, - ) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if args.eval_stride > 0: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} seq_len:{args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len} " - f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Thu Mar 19 17:34:18 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 26C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 27C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 29C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 25C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 25C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 28C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 26C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 24C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 29912 C /usr/local/bin/python 1510MiB | -| 1 N/A N/A 29913 C /usr/local/bin/python 1510MiB | -| 2 N/A N/A 29914 C /usr/local/bin/python 1510MiB | -| 3 N/A N/A 29915 C /usr/local/bin/python 1510MiB | -| 4 N/A N/A 29916 C /usr/local/bin/python 1510MiB | -| 5 N/A N/A 29917 C /usr/local/bin/python 1510MiB | -| 6 N/A N/A 29918 C /usr/local/bin/python 1510MiB | -| 7 N/A N/A 29919 C /usr/local/bin/python 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:21778504 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9378 val_bpb:4.1090 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9364 train_time:33ms step_avg:33.19ms -step:2/20000 train_loss:12.3383 train_time:81ms step_avg:40.53ms -step:3/20000 train_loss:7.4163 train_time:130ms step_avg:43.28ms -step:4/20000 train_loss:6.3475 train_time:180ms step_avg:44.88ms -step:5/20000 train_loss:6.7813 train_time:228ms step_avg:45.56ms -step:6/20000 train_loss:7.4845 train_time:276ms step_avg:46.04ms -step:7/20000 train_loss:6.7421 train_time:326ms step_avg:46.51ms -step:8/20000 train_loss:6.4367 train_time:374ms step_avg:46.72ms -step:9/20000 train_loss:6.2890 train_time:422ms step_avg:46.92ms -step:10/20000 train_loss:6.1643 train_time:473ms step_avg:47.26ms -step:2000/20000 train_loss:2.2222 train_time:97250ms step_avg:48.62ms -step:4000/20000 train_loss:2.1321 train_time:194489ms step_avg:48.62ms -step:6000/20000 train_loss:2.1823 train_time:292006ms step_avg:48.67ms -step:6000/20000 val_loss:2.1045 val_bpb:1.2464 train_time:292027ms step_avg:48.67ms -step:8000/20000 train_loss:2.0748 train_time:389371ms step_avg:48.67ms -step:10000/20000 train_loss:2.0717 train_time:486716ms step_avg:48.67ms -step:12000/20000 train_loss:2.0082 train_time:583997ms step_avg:48.67ms -step:12000/20000 val_loss:2.0352 val_bpb:1.2053 train_time:584011ms step_avg:48.67ms -step:12329/20000 val_loss:2.0273 val_bpb:1.2007 train_time:600040ms step_avg:48.67ms -stopping_early: wallclock_cap train_time:600040ms step:12329/20000 -peak memory allocated: 11188 MiB reserved: 11490 MiB -Serialized model: 86099351 bytes -Code size: 55752 bytes -Total submission size: 86155103 bytes -Serialized model int8+zlib: 15637847 bytes (payload:22428960 raw_torch:22473755 payload_ratio:3.84x) -Total submission size int8+zlib: 15693599 bytes -final_int8_zlib_roundtrip val_loss:2.0258 val_bpb:1.1998 eval_time:1670ms -final_int8_zlib_roundtrip_exact val_loss:2.02579640 val_bpb:1.19979073 -final_sliding_window val_loss:1.9698 val_bpb:1.1666 stride:64 seq_len:1024 eval_time:72611ms -final_sliding_window_exact val_loss:1.96976839 val_bpb:1.16660881 diff --git a/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/train_gpt.py b/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/train_gpt.py deleted file mode 100644 index 56ab49f0fa..0000000000 --- a/records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/train_gpt.py +++ /dev/null @@ -1,1419 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -try: - import zstandard as zstd_mod -except ImportError: - zstd_mod = None -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - fp16_embed_export = bool(int(os.environ.get("FP16_EMBED_EXPORT", "0"))) - int6_layer_start = int(os.environ.get("INT6_LAYER_START", "-1")) - int6_layer_end = int(os.environ.get("INT6_LAYER_END", "-1")) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - qat_int6 = bool(int(os.environ.get("QAT_INT6", "0"))) - eval_stride = int(os.environ.get("EVAL_STRIDE", "0")) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", "0")) - eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "32")) - eval_ntk_alpha = float(os.environ.get("EVAL_NTK_ALPHA", "0")) - block_lars_trust = float(os.environ.get("BLOCK_LARS_TRUST", "0")) - block_lars_min_scale = float(os.environ.get("BLOCK_LARS_MIN_SCALE", "0.01")) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) - adam_weight_decay = float(os.environ.get("ADAM_WEIGHT_DECAY", 0.0)) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.0)) - use_zstd = bool(int(os.environ.get("USE_ZSTD", "0"))) - zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def quantize_float_tensor_int6(t: Tensor) -> tuple[Tensor, Tensor]: - """Int6-equivalent: quantize to int8 but round to multiples of 4 (64 levels).""" - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q_raw = torch.round(clipped / scale[:, None]) - q = torch.clamp((torch.round(q_raw / 4) * 4), -128, 124).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale_val = clip_abs / 127.0 if clip_abs > 0 else 1.0 - scale = torch.tensor(scale_val, dtype=torch.float32) - q_raw = torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale) - q = torch.clamp((torch.round(q_raw / 4) * 4), -128, 124).to(torch.int8).contiguous() - return q, scale - - -def quantize_state_dict_int8(state_dict: dict[str, Tensor], fp16_embed: bool = False, int6_layer_start: int = -1, int6_layer_end: int = -1): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Keep tied embedding in fp16 for better roundtrip quality - if fp16_embed and "tok_emb.weight" in name: - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - # Use int6 for middle layers to save space - use_int6 = False - if int6_layer_start >= 0 and int6_layer_end >= 0: - for layer_idx in range(int6_layer_start, int6_layer_end + 1): - if f"blocks.{layer_idx}." in name and t.ndim == 2: - use_int6 = True - break - q, s = quantize_float_tensor_int6(t) if use_int6 else quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - _qat: bool = False - _qat_int6: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight - if self._qat and self.training and w.ndim == 2: - w_f = w.float() - amax = w_f.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) - scale = amax / 127.0 - q_raw = (w_f / scale).round() - if self._qat_int6: - q = (q_raw / 4).round() * 4 - q = q.clamp(-128, 124) - else: - q = q_raw.clamp(-127, 127) - w_q = q * scale - w = w + (w_q - w_f).detach() # STE - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - self.dim = dim - self.base = base - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._ntk_alpha_cached = 0.0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype, ntk_alpha: float = 0.0) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._ntk_alpha_cached != ntk_alpha - or self._cos_cached.device != device - ): - if ntk_alpha > 0: - base_scaled = self.base * ntk_alpha - inv_freq = 1.0 / (base_scaled ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - self._ntk_alpha_cached = ntk_alpha - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - ntk_alpha = getattr(self, "_ntk_alpha", 0.0) - cos, sin = self.rotary(seqlen, x.device, q.dtype, ntk_alpha=ntk_alpha) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - - - -class BlockLARS: - """Block-level LARS: scale gradients per block (attn/mlp/other) by trust * ||W|| / ||grad||.""" - def __init__(self, model: nn.Module, trust: float = 0.02, min_scale: float = 0.01, max_scale: float = 20.0): - self.trust = trust - self.min_scale = min_scale - self.max_scale = max_scale - self.blocks: dict[str, list[nn.Parameter]] = {"attn": [], "mlp": [], "other": []} - for name, p in model.named_parameters(): - if not p.requires_grad: - continue - if "attn" in name: - self.blocks["attn"].append(p) - elif "mlp" in name: - self.blocks["mlp"].append(p) - else: - self.blocks["other"].append(p) - self.blocks = {k: v for k, v in self.blocks.items() if v} - - @torch.no_grad() - def step(self) -> None: - for params in self.blocks.values(): - w_sq = sum(p.data.norm().item() ** 2 for p in params) - g_sq = sum(p.grad.data.norm().item() ** 2 for p in params if p.grad is not None) - w_norm, g_norm = w_sq ** 0.5, g_sq ** 0.5 - if g_norm > 1e-8: - scale = max(self.min_scale, min(self.max_scale, self.trust * w_norm / g_norm)) - else: - scale = 1.0 - for p in params: - if p.grad is not None: - p.grad.data.mul_(scale) - - -def set_ntk_alpha(model: nn.Module, alpha: float) -> None: - """Set NTK-aware RoPE scaling on all attention modules.""" - for module in model.modules(): - if isinstance(module, CausalSelfAttention): - module._ntk_alpha = alpha - - -def forward_logits(model: nn.Module, input_ids: Tensor) -> Tensor: - """Forward pass returning logits (for sliding window eval). Uses uncompiled model.""" - x = model.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(model.num_encoder_layers): - x = model.blocks[i](x, x0) - skips.append(x) - for i in range(model.num_decoder_layers): - if skips: - x = x + model.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = model.blocks[model.num_encoder_layers + i](x, x0) - x = model.final_norm(x) - if model.tie_embeddings: - logits_proj = F.linear(x, model.tok_emb.weight) - else: - if model.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = model.lm_head(x) - return model.logit_softcap * torch.tanh(logits_proj / model.logit_softcap) - - -def eval_val_sliding( - args: "Hyperparameters", - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - """Sliding window evaluation for better BPB.""" - seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - stride = args.eval_stride if args.eval_stride > 0 else seq_len - batch_size = args.eval_batch_size - total_len = val_tokens.numel() - - # Generate window start positions - positions = list(range(0, total_len - seq_len, stride)) - # Distribute across ranks (interleaved for balance) - rank_positions = positions[rank::world_size] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for batch_start in range(0, len(rank_positions), batch_size): - batch_pos = rank_positions[batch_start : batch_start + batch_size] - bs = len(batch_pos) - - x = torch.stack( - [val_tokens[p : p + seq_len] for p in batch_pos] - ).to(device=device, dtype=torch.int64) - y = torch.stack( - [val_tokens[p + 1 : p + seq_len + 1] for p in batch_pos] - ).to(device=device, dtype=torch.int64) - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = forward_logits(base_model, x) # [bs, seq_len, vocab] - - # Score only the last stride tokens of each window - score_start = seq_len - stride - logits_s = logits[:, score_start:, :].float() - targets_s = y[:, score_start:] - - per_tok = F.cross_entropy( - logits_s.reshape(-1, logits_s.size(-1)), - targets_s.reshape(-1), - reduction="none", - ) - loss_sum += per_tok.to(torch.float64).sum() - token_count += float(targets_s.numel()) - - prev_ids = x[:, score_start:].reshape(-1) - tgt_ids = targets_s.reshape(-1) - tb = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) - byte_count += tb.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - return val_loss, float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - if args.qat_enabled: - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module._qat = True - module._qat_int6 = args.qat_int6 - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - AdamCls = torch.optim.AdamW if args.adam_weight_decay > 0 else torch.optim.Adam - adam_wd = args.adam_weight_decay - optimizer_tok = AdamCls( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = AdamCls( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = AdamCls( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=adam_wd, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - block_lars = BlockLARS(base_model, trust=args.block_lars_trust, min_scale=args.block_lars_min_scale) if args.block_lars_trust > 0 else None - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - # SWA state - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - if args.block_lars_trust > 0: - block_lars.step() - for opt in optimizers: - opt.step() - # Decoupled weight decay on Muon (matrix) params - if args.muon_weight_decay > 0: - with torch.no_grad(): - muon_lr = optimizer_muon.param_groups[0]["lr"] - for p in matrix_params: - p.mul_(1.0 - muon_lr * args.muon_weight_decay) - zero_grad_all() - - # SWA: accumulate weight average during warmdown - if args.swa_start_frac > 0 and scale < args.swa_start_frac: - if swa_state is None: - swa_state = {k: v.detach().clone() for k, v in base_model.state_dict().items()} - swa_count = 1 - else: - for k, v in base_model.state_dict().items(): - swa_state[k].add_(v.detach()) - swa_count += 1 - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Apply SWA averaged weights if available - if swa_state is not None and swa_count > 1: - log0(f"SWA: applying averaged weights from {swa_count} checkpoints") - for k in swa_state: - swa_state[k].div_(swa_count) - base_model.load_state_dict(swa_state, strict=True) - restore_low_dim_params_to_fp32(base_model) - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8( - base_model.state_dict(), - fp16_embed=args.fp16_embed_export if args.tie_embeddings else False, - int6_layer_start=args.int6_layer_start, - int6_layer_end=args.int6_layer_end, - ) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if args.use_zstd and zstd_mod is not None: - cctx = zstd_mod.ZstdCompressor(level=args.zstd_level) - quant_blob = cctx.compress(quant_raw) - compress_label = f"int8+zstd{args.zstd_level}" - else: - quant_blob = zlib.compress(quant_raw, level=9) - compress_label = "int8+zlib" - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model {compress_label}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size {compress_label}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if args.use_zstd and zstd_mod is not None: - dctx = zstd_mod.ZstdDecompressor() - quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") - else: - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if args.eval_stride > 0: - if args.eval_ntk_alpha > 0: - set_ntk_alpha(base_model, args.eval_ntk_alpha) - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} seq_len:{args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len} " - f"ntk_alpha:{args.eval_ntk_alpha} " - f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_ntk_alpha > 0: - set_ntk_alpha(base_model, 0.0) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md deleted file mode 100644 index ea0b1d8d34..0000000000 --- a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/README.md +++ /dev/null @@ -1,104 +0,0 @@ -Four orthogonal improvements over the naive baseline, each contributing independently to the final score. - -### Changes from Baseline - -**1. Wider MLP (MLP\_MULT=3)** - -The baseline uses a 2x MLP expansion (hidden=1024). We widen to 3x (hidden=1536), increasing total parameters from 17.1M to 21.8M. The wider MLP is enabled by the int6 quantization scheme below which keeps the artifact under 16MB. - -**2. Mixed-Precision Post-Training Quantization** - -Key insight: during training, all `CastedLinear` weights get **fake int6 quantization via Straight-Through Estimator (STE)** — the forward pass uses quantized weights while gradients flow through the originals. This teaches the model weight distributions that survive int6 (31-level) quantization. However, the token embedding (`tok\_emb.weight`) is a plain `nn.Embedding` that **never sees fake quantization during training**. - -Previous approaches applied uniform int6 to all 2D tensors, causing a +0.048 BPB quantization penalty dominated by embedding degradation. Our mixed scheme: - -* **int6 per-row** (31 levels) on all 2D block weights (attention projections, MLP layers) — these have STE protection -* **int8 per-row** (127 levels) on the token embedding — no STE protection, needs gentler quantization -* Small/control tensors pass through as fp16/fp32 - -This reduces the quantization penalty from +0.048 to +0.0015 BPB — a 32x improvement. The int6 values are stored in int8 containers; zlib-9 compresses the zero high bits efficiently. - -**3. Optimized Training Configuration** - -* `TRAIN\_SEQ\_LEN=1024` (down from 4096): Attention is O(N²) in sequence length. Shorter sequences = faster steps (48.4ms vs 55.5ms) = more total training in the 10-minute window. The 512-dim model cannot meaningfully exploit 4K context. -* `TRAIN\_BATCH\_TOKENS=524,288` (up from 393,216): Better GPU saturation at seq\_len=1024, \~33% more tokens per step. -* Result: 12,395 steps × 524K tokens = \~6.50B total tokens (vs \~4.25B with the old config). - -**4. Sliding Window Evaluation (stride=64)** - -Instead of non-overlapping evaluation where early tokens in each chunk get minimal context, we use overlapping windows advanced by 64 tokens. Each window runs the full 1024-token forward pass, but only the last 64 tokens are scored. Every scored token gets 960 tokens of preceding context. - -Sliding window eval improves val\_bpb by \~0.034 with zero artifact cost. stride=64 gives more context per token than stride=256 (960 vs 768), at the cost of longer eval time (\~73s vs \~18s). - -### Configuration - -``` -MLP\_MULT=3 -NUM\_LAYERS=9 -MODEL\_DIM=512 -NUM\_HEADS=8 -NUM\_KV\_HEADS=4 -VOCAB\_SIZE=1024 -TRAIN\_SEQ\_LEN=1024 -TRAIN\_BATCH\_TOKENS=524288 -TIE\_EMBEDDINGS=1 -EVAL\_STRIDE=64 -``` - -Optimizer settings (tuned via env vars, no code changes from baseline optimizer structure): - -``` -MATRIX\_LR=0.020 -SCALAR\_LR=0.020 -TIED\_EMBED\_LR=0.030 -MUON\_MOMENTUM=0.99 -MUON\_MOMENTUM\_WARMUP\_STEPS=1500 -MUON\_MOMENTUM\_WARMUP\_START=0.92 -WARMDOWN\_ITERS=3000 -``` - -### Run Command - -```bash -RUN\_ID=v2\_int6\_qat\_mlp3 \\ -MAX\_WALLCLOCK\_SECONDS=600 \\ -VAL\_LOSS\_EVERY=2000 \\ -TRAIN\_LOG\_EVERY=200 \\ -torchrun --standalone --nproc\_per\_node=8 train\_gpt.py -``` - -### Key Metrics - -* Training stopped at **12,395/20,000** steps due to 10-minute wallclock cap -* Step time: **48.41ms** average on 8xH100 SXM -* Total train tokens: \~6,499,880,000 (12,395 steps × 524,288 tokens/step) -* Peak memory: **11,251 MiB** allocated per GPU - -|Metric|Value| -|-|-| -|Pre-quant val\_bpb (last step)|1.1950| -|int6/int8 mixed roundtrip val\_bpb (standard)|1.1965| -|**int6/int8 mixed roundtrip val\_bpb (sliding, stride=64)**|**1.1630**| -|Quantization penalty (standard eval)|+0.0015 BPB| -|Sliding window eval time|72.6s| -|Compressed artifact (int6+zlib-9)|15,296,720 bytes| -|Code size|56,770 bytes| -|**Total submission size**|**15,353,490 bytes**| - -### Improvement Breakdown - -|Component|val\_bpb|Improvement vs baseline| -|-|-|-| -|Naive baseline (int8, standard eval)|1.2244|—| -|+ Wider MLP 3x + seq1024 + 524K batch|1.1950|-0.0294| -|+ Mixed quant (int6 blocks, int8 embed)|1.1965|+0.0015 penalty| -|+ Sliding window stride=64|**1.1630**|-0.0335 additional| -|**Total improvement**||**-0.0614**| - -### Included Files - -* `train\_gpt.py` — full training + mixed quantization + evaluation script -* `train.log` — complete training log from the 8xH100 run -* `submission.json` — leaderboard metadata -* `README.md` — this file - diff --git a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/submission.json b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/submission.json deleted file mode 100644 index 25922290cf..0000000000 --- a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "aquariouseworkman", - "github_id": "aquariouseworkman", - "name": "Mixed Quant (int6 blocks + int8 embeddings) + Sliding Window Eval, val_bpb=1.1630", - "blurb": "3x MLP expansion with mixed-precision quantization: int6 per-row (31 levels) on STE-protected block weights, int8 per-row (127 levels) on embedding, zlib-9 compression, sliding window evaluation at stride=64.", - "date": "2026-03-19T10:30:00Z", - "val_loss": 1.96369923, - "val_bpb": 1.16301431, - "bytes_total": 15353490, - "bytes_code": 56770 -} diff --git a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train.log b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train.log deleted file mode 100644 index cc4f1477f4..0000000000 --- a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train.log +++ /dev/null @@ -1,4464 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard as zstd - HAS_ZSTD = True -except ImportError: - HAS_ZSTD = False - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride controls how many tokens to slide between windows. - # Only the last `eval_stride` tokens per window are scored, giving each scored token - # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding_window( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int = 64, -) -> tuple[float, float]: - """Sliding window evaluation: each scored token gets (seq_len - stride) context. - - Instead of chopping validation into non-overlapping 1024-token blocks (where - the first token in each block gets zero context), we slide a 1024-token window - by `stride` tokens at a time and only score the last `stride` tokens per window. - Every scored token sees 960+ tokens of context, dramatically improving BPB. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 # need 1 extra for final target - - # Number of windows: first window at offset 0, then slide by stride - num_windows = max((total_tokens - seq_len) // stride + 1, 1) - - # Distribute windows across ranks - win_start = (num_windows * rank) // world_size - win_end = (num_windows * (rank + 1)) // world_size - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Eval batch: how many windows per forward pass. Tune for memory. - eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) - - base_model.eval() - with torch.inference_mode(): - window_list = list(range(win_start, win_end)) - num_batches = (len(window_list) + eval_batch - 1) // eval_batch - for batch_idx in range(num_batches): - batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] - bsz = len(batch_wins) - - # Build input: each window is val_tokens[w*stride : w*stride + seq_len] - inputs = torch.stack([ - val_tokens[w * stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64) # [B, seq_len] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] - - # Score only the last `stride` positions per window. - # logits[:, j, :] predicts the token AFTER position j in the input. - # So logits[:, -stride:, :] predicts tokens at input positions - # [seq_len - stride + 1, seq_len + 1) — which are the targets. - scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] - - # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] - targets = torch.stack([ - val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") - val_loss_sum += loss.to(torch.float64) - val_token_count += float(targets.numel()) - - # BPB: prev_ids are the input tokens at each scored position - prev_ids = torch.stack([ - val_tokens[w * stride + seq_len - stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - tgt_ids = targets - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if rank == 0 and (batch_idx + 1) % 100 == 0: - print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Int6 per-row quantization: [-31, 31] range stored in int8 containers. - # The unused high bits compress extremely well with zstd. - clip_abs = ( - torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars still use int8 per-tensor scale (they're tiny). - clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -INT6_QUANT_RANGE = 31 # int6: [-31, 31] -INT6_CLIP_Q = 0.9999984 - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - # During training, apply fake int6 quantization (STE) so the model learns to be robust - # to the post-training quantization that will be applied for the 16MB artifact. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if self.training and w.ndim == 2: - # Fake int6 per-row quantization with Straight-Through Estimator: - # Forward uses quantized weights, backward passes gradients through as-is. - with torch.no_grad(): - w32 = w.float() - clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) - scale = clip_abs / INT6_QUANT_RANGE - w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) - w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() # STE: value of w_q, gradient of w - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits without computing loss. Used for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if HAS_ZSTD: - cctx = zstd.ZstdCompressor(level=22) - quant_blob = cctx.compress(quant_raw) - compress_name = "zstd-22" - else: - quant_blob = zlib.compress(quant_raw, level=9) - compress_name = "zlib-9" - quant_raw_bytes = len(quant_raw) - if master_process: - quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" - with open(f"final_model.{quant_ext}", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - quant_ext = "int6.ptz" - with open(f"final_model.{quant_ext}", "rb") as f: - quant_blob_disk = f.read() - if HAS_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_decompressed = dctx.decompress(quant_blob_disk) - else: - quant_decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval: gives each scored token (seq_len - stride) context. - if args.eval_stride > 0: - torch.cuda.synchronize() - t_sw = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding_window( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window_eval stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" - ) - log0( - f"final_sliding_window_eval_exact stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Thu Mar 19 09:44:13 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 35C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 30C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | -| N/A 29C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | -| N/A 32C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | -| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | -| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | -| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | -| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 2072 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 2073 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 2074 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 2075 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 2076 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 2077 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 2078 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 2079 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:20 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:21778504 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:393216 train_seq_len:4096 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/20000 train_loss:6.9373 train_time:59ms step_avg:58.72ms -step:2/20000 train_loss:12.2289 train_time:80ms step_avg:39.82ms -step:3/20000 train_loss:7.5371 train_time:134ms step_avg:44.70ms -step:4/20000 train_loss:6.2662 train_time:189ms step_avg:47.18ms -step:5/20000 train_loss:6.7749 train_time:244ms step_avg:48.71ms -step:6/20000 train_loss:6.8398 train_time:298ms step_avg:49.73ms -step:7/20000 train_loss:6.7133 train_time:353ms step_avg:50.43ms -step:8/20000 train_loss:6.6259 train_time:408ms step_avg:50.98ms -step:9/20000 train_loss:6.2381 train_time:463ms step_avg:51.41ms -step:10/20000 train_loss:6.1987 train_time:517ms step_avg:51.73ms -step:200/20000 train_loss:2.7737 train_time:10932ms step_avg:54.66ms -step:400/20000 train_loss:2.4091 train_time:21963ms step_avg:54.91ms -step:600/20000 train_loss:2.2791 train_time:33022ms step_avg:55.04ms -step:800/20000 train_loss:2.3532 train_time:44111ms step_avg:55.14ms -step:1000/20000 train_loss:2.3124 train_time:55178ms step_avg:55.18ms -step:1200/20000 train_loss:2.3323 train_time:66301ms step_avg:55.25ms -step:1400/20000 train_loss:2.2776 train_time:77425ms step_avg:55.30ms -step:1600/20000 train_loss:2.2437 train_time:88549ms step_avg:55.34ms -step:1800/20000 train_loss:2.0226 train_time:99669ms step_avg:55.37ms -step:2000/20000 train_loss:2.2808 train_time:110737ms step_avg:55.37ms -step:2200/20000 train_loss:2.0338 train_time:121855ms step_avg:55.39ms -step:2400/20000 train_loss:2.1836 train_time:132969ms step_avg:55.40ms -step:2600/20000 train_loss:2.3013 train_time:144085ms step_avg:55.42ms -step:2800/20000 train_loss:2.3646 train_time:155210ms step_avg:55.43ms -step:3000/20000 train_loss:2.1464 train_time:166276ms step_avg:55.43ms -step:3200/20000 train_loss:2.1688 train_time:177394ms step_avg:55.44ms -step:3400/20000 train_loss:2.1724 train_time:188515ms step_avg:55.45ms -step:3600/20000 train_loss:2.0546 train_time:199637ms step_avg:55.45ms -step:3800/20000 train_loss:2.0704 train_time:210709ms step_avg:55.45ms -step:4000/20000 train_loss:1.8991 train_time:221829ms step_avg:55.46ms -step:4200/20000 train_loss:1.9806 train_time:232956ms step_avg:55.47ms -step:4400/20000 train_loss:2.0861 train_time:244084ms step_avg:55.47ms -step:4600/20000 train_loss:2.0158 train_time:255204ms step_avg:55.48ms -step:4800/20000 train_loss:2.0401 train_time:266278ms step_avg:55.47ms -step:5000/20000 train_loss:2.1076 train_time:277494ms step_avg:55.50ms -step:5200/20000 train_loss:2.0759 train_time:288628ms step_avg:55.51ms -step:5400/20000 train_loss:2.1771 train_time:299751ms step_avg:55.51ms -step:5600/20000 train_loss:2.0141 train_time:310878ms step_avg:55.51ms -step:5800/20000 train_loss:2.1045 train_time:321955ms step_avg:55.51ms -step:6000/20000 train_loss:2.0173 train_time:333077ms step_avg:55.51ms -step:6200/20000 train_loss:1.9834 train_time:344195ms step_avg:55.52ms -step:6400/20000 train_loss:2.0574 train_time:355323ms step_avg:55.52ms -step:6600/20000 train_loss:2.0057 train_time:366399ms step_avg:55.51ms -step:6800/20000 train_loss:1.8728 train_time:377522ms step_avg:55.52ms -step:7000/20000 train_loss:2.0247 train_time:388645ms step_avg:55.52ms -step:7200/20000 train_loss:2.0745 train_time:399763ms step_avg:55.52ms -step:7400/20000 train_loss:2.0938 train_time:410884ms step_avg:55.52ms -step:7600/20000 train_loss:1.9440 train_time:421961ms step_avg:55.52ms -step:7800/20000 train_loss:1.8961 train_time:433081ms step_avg:55.52ms -step:8000/20000 train_loss:1.9993 train_time:444199ms step_avg:55.52ms -step:8200/20000 train_loss:2.0314 train_time:455324ms step_avg:55.53ms -step:8400/20000 train_loss:2.0255 train_time:466445ms step_avg:55.53ms -step:8600/20000 train_loss:2.0221 train_time:477511ms step_avg:55.52ms -step:8800/20000 train_loss:1.9800 train_time:488632ms step_avg:55.53ms -step:9000/20000 train_loss:1.9318 train_time:499754ms step_avg:55.53ms -step:9200/20000 train_loss:2.0151 train_time:510878ms step_avg:55.53ms -step:9400/20000 train_loss:1.9677 train_time:521951ms step_avg:55.53ms -step:9600/20000 train_loss:1.8810 train_time:533075ms step_avg:55.53ms -step:9800/20000 train_loss:1.8309 train_time:544198ms step_avg:55.53ms -step:10000/20000 train_loss:2.2471 train_time:555333ms step_avg:55.53ms -step:10200/20000 train_loss:1.9500 train_time:566455ms step_avg:55.53ms -step:10400/20000 train_loss:2.0805 train_time:577533ms step_avg:55.53ms -step:10600/20000 train_loss:1.9196 train_time:588662ms step_avg:55.53ms -step:10800/20000 train_loss:1.9962 train_time:599781ms step_avg:55.54ms -step:10804/20000 val_loss:1.9787 val_bpb:1.1719 train_time:600037ms step_avg:55.54ms -stopping_early: wallclock_cap train_time:600037ms step:10804/20000 -peak memory allocated: 8517 MiB reserved: 9032 MiB -Serialized model: 86099351 bytes -Code size: 55931 bytes -Total submission size: 86155282 bytes -Serialized model int6+zlib-9: 15160388 bytes (payload:21906720 raw_torch:21951833 payload_ratio:3.93x) -Total submission size int6+zlib-9: 15216319 bytes -final_int6_roundtrip val_loss:2.0436 val_bpb:1.2103 eval_time:2262ms -final_int6_roundtrip_exact val_loss:2.04356694 val_bpb:1.21031545 -final_sliding_window_eval stride:64 val_loss:2.0222 val_bpb:1.1976 eval_time:304360ms -final_sliding_window_eval_exact stride:64 val_loss:2.02215627 val_bpb:1.19763674 -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard as zstd - HAS_ZSTD = True -except ImportError: - HAS_ZSTD = False - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride controls how many tokens to slide between windows. - # Only the last `eval_stride` tokens per window are scored, giving each scored token - # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). - eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding_window( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int = 64, -) -> tuple[float, float]: - """Sliding window evaluation: each scored token gets (seq_len - stride) context. - - Instead of chopping validation into non-overlapping 1024-token blocks (where - the first token in each block gets zero context), we slide a 1024-token window - by `stride` tokens at a time and only score the last `stride` tokens per window. - Every scored token sees 960+ tokens of context, dramatically improving BPB. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 # need 1 extra for final target - - # Number of windows: first window at offset 0, then slide by stride - num_windows = max((total_tokens - seq_len) // stride + 1, 1) - - # Distribute windows across ranks - win_start = (num_windows * rank) // world_size - win_end = (num_windows * (rank + 1)) // world_size - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Eval batch: how many windows per forward pass. Tune for memory. - eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) - - base_model.eval() - with torch.inference_mode(): - window_list = list(range(win_start, win_end)) - num_batches = (len(window_list) + eval_batch - 1) // eval_batch - for batch_idx in range(num_batches): - batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] - bsz = len(batch_wins) - - # Build input: each window is val_tokens[w*stride : w*stride + seq_len] - inputs = torch.stack([ - val_tokens[w * stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64) # [B, seq_len] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] - - # Score only the last `stride` positions per window. - # logits[:, j, :] predicts the token AFTER position j in the input. - # So logits[:, -stride:, :] predicts tokens at input positions - # [seq_len - stride + 1, seq_len + 1) — which are the targets. - scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] - - # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] - targets = torch.stack([ - val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") - val_loss_sum += loss.to(torch.float64) - val_token_count += float(targets.numel()) - - # BPB: prev_ids are the input tokens at each scored position - prev_ids = torch.stack([ - val_tokens[w * stride + seq_len - stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - tgt_ids = targets - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if rank == 0 and (batch_idx + 1) % 100 == 0: - print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Int6 per-row quantization: [-31, 31] range stored in int8 containers. - # The unused high bits compress extremely well with zstd. - clip_abs = ( - torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars still use int8 per-tensor scale (they're tiny). - clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -INT6_QUANT_RANGE = 31 # int6: [-31, 31] -INT6_CLIP_Q = 0.9999984 - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - # During training, apply fake int6 quantization (STE) so the model learns to be robust - # to the post-training quantization that will be applied for the 16MB artifact. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if self.training and w.ndim == 2: - # Fake int6 per-row quantization with Straight-Through Estimator: - # Forward uses quantized weights, backward passes gradients through as-is. - with torch.no_grad(): - w32 = w.float() - clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) - scale = clip_abs / INT6_QUANT_RANGE - w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) - w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() # STE: value of w_q, gradient of w - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits without computing loss. Used for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if HAS_ZSTD: - cctx = zstd.ZstdCompressor(level=22) - quant_blob = cctx.compress(quant_raw) - compress_name = "zstd-22" - else: - quant_blob = zlib.compress(quant_raw, level=9) - compress_name = "zlib-9" - quant_raw_bytes = len(quant_raw) - if master_process: - quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" - with open(f"final_model.{quant_ext}", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - quant_ext = "int6.ptz" - with open(f"final_model.{quant_ext}", "rb") as f: - quant_blob_disk = f.read() - if HAS_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_decompressed = dctx.decompress(quant_blob_disk) - else: - quant_decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval: gives each scored token (seq_len - stride) context. - if args.eval_stride > 0: - torch.cuda.synchronize() - t_sw = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding_window( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window_eval stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" - ) - log0( - f"final_sliding_window_eval_exact stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Thu Mar 19 10:08:37 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 35C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 30C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | -| N/A 29C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | -| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | -| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | -| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | -| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | -| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 36375 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 36376 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 36377 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 36378 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 36379 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 36380 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 36381 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 36382 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:20 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:21778504 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9378 val_bpb:4.1090 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9364 train_time:25ms step_avg:25.10ms -step:2/20000 train_loss:12.3366 train_time:70ms step_avg:35.22ms -step:3/20000 train_loss:7.4143 train_time:118ms step_avg:39.45ms -step:4/20000 train_loss:6.3474 train_time:165ms step_avg:41.32ms -step:5/20000 train_loss:6.7814 train_time:213ms step_avg:42.66ms -step:6/20000 train_loss:7.4895 train_time:261ms step_avg:43.54ms -step:7/20000 train_loss:6.7486 train_time:309ms step_avg:44.11ms -step:8/20000 train_loss:6.4461 train_time:356ms step_avg:44.54ms -step:9/20000 train_loss:6.3008 train_time:404ms step_avg:44.87ms -step:10/20000 train_loss:6.1806 train_time:451ms step_avg:45.14ms -step:200/20000 train_loss:2.7609 train_time:9576ms step_avg:47.88ms -step:400/20000 train_loss:2.2985 train_time:19184ms step_avg:47.96ms -step:600/20000 train_loss:2.5044 train_time:28808ms step_avg:48.01ms -step:800/20000 train_loss:2.2650 train_time:38451ms step_avg:48.06ms -step:1000/20000 train_loss:2.3522 train_time:48118ms step_avg:48.12ms -step:1200/20000 train_loss:2.3722 train_time:57803ms step_avg:48.17ms -step:1400/20000 train_loss:2.4249 train_time:67502ms step_avg:48.22ms -step:1600/20000 train_loss:2.0953 train_time:77188ms step_avg:48.24ms -step:1800/20000 train_loss:2.1866 train_time:86887ms step_avg:48.27ms -step:2000/20000 train_loss:2.2375 train_time:96583ms step_avg:48.29ms -step:2000/20000 val_loss:2.2181 val_bpb:1.3137 train_time:96613ms step_avg:48.31ms -step:2200/20000 train_loss:2.0498 train_time:106273ms step_avg:48.31ms -step:2400/20000 train_loss:2.1729 train_time:115967ms step_avg:48.32ms -step:2600/20000 train_loss:2.3764 train_time:125661ms step_avg:48.33ms -step:2800/20000 train_loss:2.1998 train_time:135356ms step_avg:48.34ms -step:3000/20000 train_loss:2.1901 train_time:145046ms step_avg:48.35ms -step:3200/20000 train_loss:2.1510 train_time:154732ms step_avg:48.35ms -step:3400/20000 train_loss:2.1192 train_time:164420ms step_avg:48.36ms -step:3600/20000 train_loss:2.0793 train_time:174112ms step_avg:48.36ms -step:3800/20000 train_loss:2.1825 train_time:183806ms step_avg:48.37ms -step:4000/20000 train_loss:2.2339 train_time:193495ms step_avg:48.37ms -step:4000/20000 val_loss:2.1286 val_bpb:1.2607 train_time:193525ms step_avg:48.38ms -step:4200/20000 train_loss:2.1753 train_time:203239ms step_avg:48.39ms -step:4400/20000 train_loss:2.1257 train_time:212931ms step_avg:48.39ms -step:4600/20000 train_loss:2.1593 train_time:222632ms step_avg:48.40ms -step:4800/20000 train_loss:2.1010 train_time:232323ms step_avg:48.40ms -step:5000/20000 train_loss:2.1841 train_time:242009ms step_avg:48.40ms -step:5200/20000 train_loss:2.2419 train_time:251704ms step_avg:48.40ms -step:5400/20000 train_loss:2.1947 train_time:261390ms step_avg:48.41ms -step:5600/20000 train_loss:2.1010 train_time:271090ms step_avg:48.41ms -step:5800/20000 train_loss:2.2822 train_time:280782ms step_avg:48.41ms -step:6000/20000 train_loss:2.0261 train_time:290480ms step_avg:48.41ms -step:6000/20000 val_loss:2.0964 val_bpb:1.2416 train_time:290510ms step_avg:48.42ms -step:6200/20000 train_loss:2.1061 train_time:300162ms step_avg:48.41ms -step:6400/20000 train_loss:1.9307 train_time:309857ms step_avg:48.42ms -step:6600/20000 train_loss:2.0882 train_time:319546ms step_avg:48.42ms -step:6800/20000 train_loss:2.1693 train_time:329233ms step_avg:48.42ms -step:7000/20000 train_loss:2.0145 train_time:338923ms step_avg:48.42ms -step:7200/20000 train_loss:2.0029 train_time:348614ms step_avg:48.42ms -step:7400/20000 train_loss:1.9361 train_time:358310ms step_avg:48.42ms -step:7600/20000 train_loss:2.0350 train_time:368000ms step_avg:48.42ms -step:7800/20000 train_loss:2.0813 train_time:377688ms step_avg:48.42ms -step:8000/20000 train_loss:2.0144 train_time:387378ms step_avg:48.42ms -step:8000/20000 val_loss:2.0773 val_bpb:1.2303 train_time:387408ms step_avg:48.43ms -step:8200/20000 train_loss:2.1489 train_time:397067ms step_avg:48.42ms -step:8400/20000 train_loss:2.1424 train_time:406801ms step_avg:48.43ms -step:8600/20000 train_loss:2.1657 train_time:416481ms step_avg:48.43ms -step:8800/20000 train_loss:2.0297 train_time:426178ms step_avg:48.43ms -step:9000/20000 train_loss:2.0599 train_time:435859ms step_avg:48.43ms -step:9200/20000 train_loss:2.1330 train_time:445542ms step_avg:48.43ms -step:9400/20000 train_loss:1.9739 train_time:455237ms step_avg:48.43ms -step:9600/20000 train_loss:2.0265 train_time:464918ms step_avg:48.43ms -step:9800/20000 train_loss:2.0734 train_time:474609ms step_avg:48.43ms -step:10000/20000 train_loss:2.1375 train_time:484302ms step_avg:48.43ms -step:10000/20000 val_loss:2.0568 val_bpb:1.2182 train_time:484331ms step_avg:48.43ms -step:10200/20000 train_loss:1.9916 train_time:493993ms step_avg:48.43ms -step:10400/20000 train_loss:2.2391 train_time:503676ms step_avg:48.43ms -step:10600/20000 train_loss:1.8816 train_time:513372ms step_avg:48.43ms -step:10800/20000 train_loss:2.0403 train_time:523069ms step_avg:48.43ms -step:11000/20000 train_loss:2.1571 train_time:532770ms step_avg:48.43ms -step:11200/20000 train_loss:2.1207 train_time:542462ms step_avg:48.43ms -step:11400/20000 train_loss:2.1568 train_time:552158ms step_avg:48.43ms -step:11600/20000 train_loss:2.0102 train_time:561851ms step_avg:48.44ms -step:11800/20000 train_loss:2.0893 train_time:571552ms step_avg:48.44ms -step:12000/20000 train_loss:1.9977 train_time:581254ms step_avg:48.44ms -step:12000/20000 val_loss:2.0216 val_bpb:1.1973 train_time:581283ms step_avg:48.44ms -step:12200/20000 train_loss:2.1668 train_time:590989ms step_avg:48.44ms -step:12386/20000 val_loss:2.0178 val_bpb:1.1951 train_time:600039ms step_avg:48.44ms -stopping_early: wallclock_cap train_time:600039ms step:12386/20000 -peak memory allocated: 11254 MiB reserved: 11520 MiB -Serialized model: 86099351 bytes -Code size: 55932 bytes -Total submission size: 86155283 bytes -Serialized model int6+zlib-9: 15164668 bytes (payload:21906720 raw_torch:21951833 payload_ratio:3.93x) -Total submission size int6+zlib-9: 15220600 bytes -final_int6_roundtrip val_loss:2.0991 val_bpb:1.2432 eval_time:1548ms -final_int6_roundtrip_exact val_loss:2.09906746 val_bpb:1.24318598 -final_sliding_window_eval stride:256 val_loss:2.0403 val_bpb:1.2084 eval_time:18256ms -final_sliding_window_eval_exact stride:256 val_loss:2.04033873 val_bpb:1.20840424 -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard as zstd - HAS_ZSTD = True -except ImportError: - HAS_ZSTD = False - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride controls how many tokens to slide between windows. - # Only the last `eval_stride` tokens per window are scored, giving each scored token - # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding_window( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int = 64, -) -> tuple[float, float]: - """Sliding window evaluation: each scored token gets (seq_len - stride) context. - - Instead of chopping validation into non-overlapping 1024-token blocks (where - the first token in each block gets zero context), we slide a 1024-token window - by `stride` tokens at a time and only score the last `stride` tokens per window. - Every scored token sees 960+ tokens of context, dramatically improving BPB. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 # need 1 extra for final target - - # Number of windows: first window at offset 0, then slide by stride - num_windows = max((total_tokens - seq_len) // stride + 1, 1) - - # Distribute windows across ranks - win_start = (num_windows * rank) // world_size - win_end = (num_windows * (rank + 1)) // world_size - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Eval batch: how many windows per forward pass. Tune for memory. - eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) - - base_model.eval() - with torch.inference_mode(): - window_list = list(range(win_start, win_end)) - num_batches = (len(window_list) + eval_batch - 1) // eval_batch - for batch_idx in range(num_batches): - batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] - bsz = len(batch_wins) - - # Build input: each window is val_tokens[w*stride : w*stride + seq_len] - inputs = torch.stack([ - val_tokens[w * stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64) # [B, seq_len] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] - - # Score only the last `stride` positions per window. - # logits[:, j, :] predicts the token AFTER position j in the input. - # So logits[:, -stride:, :] predicts tokens at input positions - # [seq_len - stride + 1, seq_len + 1) — which are the targets. - scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] - - # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] - targets = torch.stack([ - val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") - val_loss_sum += loss.to(torch.float64) - val_token_count += float(targets.numel()) - - # BPB: prev_ids are the input tokens at each scored position - prev_ids = torch.stack([ - val_tokens[w * stride + seq_len - stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - tgt_ids = targets - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if rank == 0 and (batch_idx + 1) % 100 == 0: - print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 - -INT6_QUANT_RANGE = 31 # int6: [-31, 31] -INT8_QUANT_RANGE = 127 # int8: [-127, 127] -INT6_CLIP_Q = 0.9999984 - -# Tensors matching these patterns get int8 (127 levels) instead of int6 (31 levels) -# during post-training quantization. This is for weights that DON'T have fake-quant -# STE protection during training (e.g. nn.Embedding), so they need gentler quantization. -INT8_FULL_RANGE_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_FULL_RANGE_PATTERNS", - "tok_emb", - ).split(",") - if pattern -) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor, quant_range: int = INT6_QUANT_RANGE) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Per-row quantization: [-quant_range, quant_range] stored in int8 containers. - # For int6 (range=31), the unused high bits compress extremely well with zstd. - clip_abs = ( - torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(quant_range)).clamp_min(1.0 / float(quant_range)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -quant_range, quant_range).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars still use int8 per-tensor scale (they're tiny). - clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - # Use int8 (127 levels) for tensors without STE fake-quant protection - # (e.g. tok_emb), int6 (31 levels) for block weights that trained with STE. - use_int8 = any(pattern in name for pattern in INT8_FULL_RANGE_PATTERNS) - qr = INT8_QUANT_RANGE if use_int8 else INT6_QUANT_RANGE - q, s = quantize_float_tensor(t, quant_range=qr) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - # During training, apply fake int6 quantization (STE) so the model learns to be robust - # to the post-training quantization that will be applied for the 16MB artifact. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if self.training and w.ndim == 2: - # Fake int6 per-row quantization with Straight-Through Estimator: - # Forward uses quantized weights, backward passes gradients through as-is. - with torch.no_grad(): - w32 = w.float() - clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) - scale = clip_abs / INT6_QUANT_RANGE - w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) - w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() # STE: value of w_q, gradient of w - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits without computing loss. Used for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if HAS_ZSTD: - cctx = zstd.ZstdCompressor(level=22) - quant_blob = cctx.compress(quant_raw) - compress_name = "zstd-22" - else: - quant_blob = zlib.compress(quant_raw, level=9) - compress_name = "zlib-9" - quant_raw_bytes = len(quant_raw) - if master_process: - quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" - with open(f"final_model.{quant_ext}", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - quant_ext = "int6.ptz" - with open(f"final_model.{quant_ext}", "rb") as f: - quant_blob_disk = f.read() - if HAS_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_decompressed = dctx.decompress(quant_blob_disk) - else: - quant_decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval: gives each scored token (seq_len - stride) context. - if args.eval_stride > 0: - torch.cuda.synchronize() - t_sw = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding_window( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window_eval stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" - ) - log0( - f"final_sliding_window_eval_exact stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Thu Mar 19 10:32:00 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 34C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 29C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | -| N/A 29C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | -| N/A 32C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | -| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | -| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | -| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | -| N/A 32C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 57118 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 57119 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 57120 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 57121 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 57122 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 57123 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 57124 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 57125 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:20 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:21778504 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9378 val_bpb:4.1090 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9364 train_time:25ms step_avg:25.39ms -step:2/20000 train_loss:12.3366 train_time:69ms step_avg:34.49ms -step:3/20000 train_loss:7.4142 train_time:116ms step_avg:38.83ms -step:4/20000 train_loss:6.3475 train_time:164ms step_avg:40.94ms -step:5/20000 train_loss:6.7813 train_time:211ms step_avg:42.24ms -step:6/20000 train_loss:7.4896 train_time:259ms step_avg:43.11ms -step:7/20000 train_loss:6.7491 train_time:307ms step_avg:43.79ms -step:8/20000 train_loss:6.4461 train_time:354ms step_avg:44.27ms -step:9/20000 train_loss:6.3003 train_time:402ms step_avg:44.63ms -step:10/20000 train_loss:6.1808 train_time:449ms step_avg:44.92ms -step:200/20000 train_loss:2.7644 train_time:9564ms step_avg:47.82ms -step:400/20000 train_loss:2.3035 train_time:19162ms step_avg:47.91ms -step:600/20000 train_loss:2.5030 train_time:28769ms step_avg:47.95ms -step:800/20000 train_loss:2.2665 train_time:38403ms step_avg:48.00ms -step:1000/20000 train_loss:2.3514 train_time:48062ms step_avg:48.06ms -step:1200/20000 train_loss:2.3714 train_time:57723ms step_avg:48.10ms -step:1400/20000 train_loss:2.4243 train_time:67404ms step_avg:48.15ms -step:1600/20000 train_loss:2.0992 train_time:77085ms step_avg:48.18ms -step:1800/20000 train_loss:2.1921 train_time:86767ms step_avg:48.20ms -step:2000/20000 train_loss:2.2339 train_time:96459ms step_avg:48.23ms -step:2000/20000 val_loss:2.2180 val_bpb:1.3136 train_time:96488ms step_avg:48.24ms -step:2200/20000 train_loss:2.0467 train_time:106139ms step_avg:48.24ms -step:2400/20000 train_loss:2.1760 train_time:115822ms step_avg:48.26ms -step:2600/20000 train_loss:2.3846 train_time:125507ms step_avg:48.27ms -step:2800/20000 train_loss:2.2012 train_time:135190ms step_avg:48.28ms -step:3000/20000 train_loss:2.1922 train_time:144871ms step_avg:48.29ms -step:3200/20000 train_loss:2.1540 train_time:154556ms step_avg:48.30ms -step:3400/20000 train_loss:2.1266 train_time:164240ms step_avg:48.31ms -step:3600/20000 train_loss:2.0762 train_time:173935ms step_avg:48.32ms -step:3800/20000 train_loss:2.1830 train_time:183612ms step_avg:48.32ms -step:4000/20000 train_loss:2.2349 train_time:193304ms step_avg:48.33ms -step:4000/20000 val_loss:2.1282 val_bpb:1.2604 train_time:193334ms step_avg:48.33ms -step:4200/20000 train_loss:2.1801 train_time:203036ms step_avg:48.34ms -step:4400/20000 train_loss:2.1307 train_time:212724ms step_avg:48.35ms -step:4600/20000 train_loss:2.1653 train_time:222417ms step_avg:48.35ms -step:4800/20000 train_loss:2.0977 train_time:232099ms step_avg:48.35ms -step:5000/20000 train_loss:2.1822 train_time:241786ms step_avg:48.36ms -step:5200/20000 train_loss:2.2444 train_time:251474ms step_avg:48.36ms -step:5400/20000 train_loss:2.1939 train_time:261163ms step_avg:48.36ms -step:5600/20000 train_loss:2.0976 train_time:270860ms step_avg:48.37ms -step:5800/20000 train_loss:2.2809 train_time:280540ms step_avg:48.37ms -step:6000/20000 train_loss:2.0262 train_time:290222ms step_avg:48.37ms -step:6000/20000 val_loss:2.0966 val_bpb:1.2417 train_time:290251ms step_avg:48.38ms -step:6200/20000 train_loss:2.1005 train_time:299909ms step_avg:48.37ms -step:6400/20000 train_loss:1.9273 train_time:309598ms step_avg:48.37ms -step:6600/20000 train_loss:2.0822 train_time:319283ms step_avg:48.38ms -step:6800/20000 train_loss:2.1677 train_time:328963ms step_avg:48.38ms -step:7000/20000 train_loss:2.0135 train_time:338650ms step_avg:48.38ms -step:7200/20000 train_loss:2.0018 train_time:348331ms step_avg:48.38ms -step:7400/20000 train_loss:1.9328 train_time:358015ms step_avg:48.38ms -step:7600/20000 train_loss:2.0359 train_time:367703ms step_avg:48.38ms -step:7800/20000 train_loss:2.0807 train_time:377384ms step_avg:48.38ms -step:8000/20000 train_loss:2.0150 train_time:387064ms step_avg:48.38ms -step:8000/20000 val_loss:2.0768 val_bpb:1.2300 train_time:387094ms step_avg:48.39ms -step:8200/20000 train_loss:2.1487 train_time:396747ms step_avg:48.38ms -step:8400/20000 train_loss:2.1440 train_time:406469ms step_avg:48.39ms -step:8600/20000 train_loss:2.1641 train_time:416138ms step_avg:48.39ms -step:8800/20000 train_loss:2.0296 train_time:425826ms step_avg:48.39ms -step:9000/20000 train_loss:2.0574 train_time:435503ms step_avg:48.39ms -step:9200/20000 train_loss:2.1356 train_time:445268ms step_avg:48.40ms -step:9400/20000 train_loss:1.9742 train_time:454945ms step_avg:48.40ms -step:9600/20000 train_loss:2.0251 train_time:464630ms step_avg:48.40ms -step:9800/20000 train_loss:2.0776 train_time:474313ms step_avg:48.40ms -step:10000/20000 train_loss:2.1372 train_time:483997ms step_avg:48.40ms -step:10000/20000 val_loss:2.0570 val_bpb:1.2183 train_time:484027ms step_avg:48.40ms -step:10200/20000 train_loss:1.9914 train_time:493676ms step_avg:48.40ms -step:10400/20000 train_loss:2.2388 train_time:503361ms step_avg:48.40ms -step:10600/20000 train_loss:1.8790 train_time:513046ms step_avg:48.40ms -step:10800/20000 train_loss:2.0378 train_time:522738ms step_avg:48.40ms -step:11000/20000 train_loss:2.1572 train_time:532426ms step_avg:48.40ms -step:11200/20000 train_loss:2.1221 train_time:542116ms step_avg:48.40ms -step:11400/20000 train_loss:2.1492 train_time:551800ms step_avg:48.40ms -step:11600/20000 train_loss:2.0106 train_time:561488ms step_avg:48.40ms -step:11800/20000 train_loss:2.0874 train_time:571185ms step_avg:48.41ms -step:12000/20000 train_loss:1.9972 train_time:580870ms step_avg:48.41ms -step:12000/20000 val_loss:2.0215 val_bpb:1.1973 train_time:580900ms step_avg:48.41ms -step:12200/20000 train_loss:2.1699 train_time:590570ms step_avg:48.41ms -step:12395/20000 val_loss:2.0177 val_bpb:1.1950 train_time:600044ms step_avg:48.41ms -stopping_early: wallclock_cap train_time:600044ms step:12395/20000 -peak memory allocated: 11251 MiB reserved: 11396 MiB -Serialized model: 86099351 bytes -Code size: 56770 bytes -Total submission size: 86156121 bytes -Serialized model int6+zlib-9: 15296720 bytes (payload:21906720 raw_torch:21951833 payload_ratio:3.93x) -Total submission size int6+zlib-9: 15353490 bytes -final_int6_roundtrip val_loss:2.0203 val_bpb:1.1965 eval_time:1547ms -final_int6_roundtrip_exact val_loss:2.02025601 val_bpb:1.19650941 -final_sliding_window_eval stride:64 val_loss:1.9637 val_bpb:1.1630 eval_time:72650ms -final_sliding_window_eval_exact stride:64 val_loss:1.96369923 val_bpb:1.16301431 diff --git a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train_gpt.py b/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train_gpt.py deleted file mode 100644 index 284a42bd72..0000000000 --- a/records/track_10min_16mb/2026-03-19_MixedQuant_Int6Int8_SlidingWindow/train_gpt.py +++ /dev/null @@ -1,1324 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard as zstd - HAS_ZSTD = True -except ImportError: - HAS_ZSTD = False - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride controls how many tokens to slide between windows. - # Only the last `eval_stride` tokens per window are scored, giving each scored token - # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding_window( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int = 64, -) -> tuple[float, float]: - """Sliding window evaluation: each scored token gets (seq_len - stride) context. - - Instead of chopping validation into non-overlapping 1024-token blocks (where - the first token in each block gets zero context), we slide a 1024-token window - by `stride` tokens at a time and only score the last `stride` tokens per window. - Every scored token sees 960+ tokens of context, dramatically improving BPB. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 # need 1 extra for final target - - # Number of windows: first window at offset 0, then slide by stride - num_windows = max((total_tokens - seq_len) // stride + 1, 1) - - # Distribute windows across ranks - win_start = (num_windows * rank) // world_size - win_end = (num_windows * (rank + 1)) // world_size - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Eval batch: how many windows per forward pass. Tune for memory. - eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) - - base_model.eval() - with torch.inference_mode(): - window_list = list(range(win_start, win_end)) - num_batches = (len(window_list) + eval_batch - 1) // eval_batch - for batch_idx in range(num_batches): - batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] - bsz = len(batch_wins) - - # Build input: each window is val_tokens[w*stride : w*stride + seq_len] - inputs = torch.stack([ - val_tokens[w * stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64) # [B, seq_len] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] - - # Score only the last `stride` positions per window. - # logits[:, j, :] predicts the token AFTER position j in the input. - # So logits[:, -stride:, :] predicts tokens at input positions - # [seq_len - stride + 1, seq_len + 1) — which are the targets. - scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] - - # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] - targets = torch.stack([ - val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") - val_loss_sum += loss.to(torch.float64) - val_token_count += float(targets.numel()) - - # BPB: prev_ids are the input tokens at each scored position - prev_ids = torch.stack([ - val_tokens[w * stride + seq_len - stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - tgt_ids = targets - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if rank == 0 and (batch_idx + 1) % 100 == 0: - print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 - -INT6_QUANT_RANGE = 31 # int6: [-31, 31] -INT8_QUANT_RANGE = 127 # int8: [-127, 127] -INT6_CLIP_Q = 0.9999984 - -# Tensors matching these patterns get int8 (127 levels) instead of int6 (31 levels) -# during post-training quantization. This is for weights that DON'T have fake-quant -# STE protection during training (e.g. nn.Embedding), so they need gentler quantization. -INT8_FULL_RANGE_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_FULL_RANGE_PATTERNS", - "tok_emb", - ).split(",") - if pattern -) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor, quant_range: int = INT6_QUANT_RANGE) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Per-row quantization: [-quant_range, quant_range] stored in int8 containers. - # For int6 (range=31), the unused high bits compress extremely well with zstd. - clip_abs = ( - torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(quant_range)).clamp_min(1.0 / float(quant_range)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -quant_range, quant_range).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars still use int8 per-tensor scale (they're tiny). - clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - # Use int8 (127 levels) for tensors without STE fake-quant protection - # (e.g. tok_emb), int6 (31 levels) for block weights that trained with STE. - use_int8 = any(pattern in name for pattern in INT8_FULL_RANGE_PATTERNS) - qr = INT8_QUANT_RANGE if use_int8 else INT6_QUANT_RANGE - q, s = quantize_float_tensor(t, quant_range=qr) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - # During training, apply fake int6 quantization (STE) so the model learns to be robust - # to the post-training quantization that will be applied for the 16MB artifact. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if self.training and w.ndim == 2: - # Fake int6 per-row quantization with Straight-Through Estimator: - # Forward uses quantized weights, backward passes gradients through as-is. - with torch.no_grad(): - w32 = w.float() - clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) - scale = clip_abs / INT6_QUANT_RANGE - w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) - w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() # STE: value of w_q, gradient of w - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits without computing loss. Used for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if HAS_ZSTD: - cctx = zstd.ZstdCompressor(level=22) - quant_blob = cctx.compress(quant_raw) - compress_name = "zstd-22" - else: - quant_blob = zlib.compress(quant_raw, level=9) - compress_name = "zlib-9" - quant_raw_bytes = len(quant_raw) - if master_process: - quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" - with open(f"final_model.{quant_ext}", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - quant_ext = "int6.ptz" - with open(f"final_model.{quant_ext}", "rb") as f: - quant_blob_disk = f.read() - if HAS_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_decompressed = dctx.decompress(quant_blob_disk) - else: - quant_decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval: gives each scored token (seq_len - stride) context. - if args.eval_stride > 0: - torch.cuda.synchronize() - t_sw = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding_window( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window_eval stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" - ) - log0( - f"final_sliding_window_eval_exact stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/README.md b/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/README.md deleted file mode 100644 index 797059c373..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/README.md +++ /dev/null @@ -1,68 +0,0 @@ -# 10L Int6 QAT + Zstd MLP2.6x Muon0.99 Sliding Window - -## Summary - -Stacked improvements on the Naive Baseline: - -1. **10 transformer layers** (from 9). - -2. **STE int6 QAT**: Straight-through estimator fake quantization during training. Each CastedLinear forward pass applies `fake_quantize_int6(w)` — quantize to [-31,31], dequantize, with gradients flowing through via STE. This teaches the model to be robust to int6 quantization, **completely eliminating the quant gap** (pre-quant = post-quant loss). - -3. **Full int6 quantization**: All 2D block weights quantized to [-31,31] (63 levels) in int8 container. - -4. **zstd-22 compression**: Better than zlib for int6 data. - -5. **MLP hidden 1344** (2.625x model_dim): Wider MLP enabled by int6+zstd savings. - -6. **FP16 tied embedding passthrough**. - -7. **Sequence length 2048**. - -8. **Muon momentum 0.99**, warmup from 0.92 over 1500 steps. - -9. **MATRIX_LR=0.02, SCALAR_LR=0.02, TIED_EMBED_LR=0.04**. - -10. **Gradient clipping** GRAD_CLIP_NORM=0.3. - -11. **Sliding window evaluation** stride=64. - -## Configuration - -```bash -MLP_HIDDEN=1344 \ -DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -MAX_WALLCLOCK_SECONDS=600 \ -EVAL_STRIDE=64 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -Requires: `pip install zstandard` - -## Results - -| Seed | Steps | val_bpb (standard) | val_bpb (sliding) | Artifact size | -|------|-------|--------------------|--------------------|---------------| -| 1337 | 8,319 | 1.1821 | 1.1610 | 15,558,319 | -| 42 | ~8,300 | ~1.1815 | 1.1598 | ~15,558,000 | -| 3 | ~8,300 | ~1.1810 | 1.1586 | ~15,558,000 | - -**Mean val_bpb (sliding): 1.1598** (std: 0.00120) -**Mean val_loss (sliding): 1.9583** (std: 0.00203) - -Quant gap: **0.0000** — STE QAT completely eliminated quantization loss. - -Statistical significance vs SOTA (1.2244 BPB / 2.0727 val_loss): -- Improvement: 0.1144 nats (threshold: 0.005) -- t-statistic: -93.6, df=2, p << 0.01 - -Hardware: 8xH100 80GB HBM3, PyTorch 2.8.0+cu128, ~72ms/step avg. -QAT overhead: ~28% (72ms vs 69ms without QAT). -Sliding window eval time: ~370s. - -## Included Files - -- `train_gpt.py` (modified training script) -- `train_seed1337.log`, `train_seed42.log`, `train_seed3.log` -- `submission.json` diff --git a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/submission.json b/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/submission.json deleted file mode 100644 index e086e6465a..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "yahya010", - "github_id": "yahya010", - "name": "10L Int6 QAT + Zstd MLP2.6x Muon0.99 Sliding Window", - "blurb": "10-layer 512dim SP-1024, STE int6 QAT (zero quant gap), full int6 [-31,31] + zstd-22, MLP hidden=1344, fp16 tied embedding, Muon 0.99, LR 0.02, grad clip 0.3, sliding window stride=64.", - "date": "2026-03-19", - "val_loss": 1.95627871, - "val_bpb": 1.15861696, - "bytes_total": 15558319, - "bytes_code": 56005 -} diff --git a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_gpt.py b/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_gpt.py deleted file mode 100644 index 555cccfdad..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_gpt.py +++ /dev/null @@ -1,1313 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -import zstandard as zstd -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride=0 means standard (non-overlapping) eval. - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 0)) # 0 = use train_seq_len - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3600)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 1536)) # 3x model_dim, enabled by int6+zstd compression - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.04)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding( - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int = 2048, - stride: int = 64, - batch_size: int = 4, -) -> tuple[float, float]: - """Sliding window eval: each token scored with max context.""" - N = val_tokens.numel() - # Distribute windows across ranks - window_starts: list[int] = [] - pos = 0 - while pos + eval_seq_len < N: - window_starts.append(pos) - if pos == 0: - pos = eval_seq_len - stride # after first full window - else: - pos += stride - - # Partition windows across ranks - per_rank = len(window_starts) // max(world_size, 1) - rank_start = rank * per_rank - rank_end = (rank + 1) * per_rank if rank < world_size - 1 else len(window_starts) - my_windows = window_starts[rank_start:rank_end] - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for batch_start in range(0, len(my_windows), batch_size): - batch_windows = my_windows[batch_start:batch_start + batch_size] - xs, ys = [], [] - for w in batch_windows: - xs.append(val_tokens[w:w + eval_seq_len]) - ys.append(val_tokens[w + 1:w + eval_seq_len + 1]) - x = torch.stack(xs).to(device=device, dtype=torch.int64) - y = torch.stack(ys).to(device=device, dtype=torch.int64) - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - per_token_loss = base_model.forward_per_token_loss(x, y) - - for i, w in enumerate(batch_windows): - if w == 0: - # First window: score all positions - score_start = 0 - else: - # Subsequent windows: score only last `stride` positions - score_start = eval_seq_len - stride - - losses = per_token_loss[i, score_start:] - tgt_ids = y[i, score_start:] - prev_ids = x[i, score_start:] - n_scored = losses.numel() - - val_loss_sum += losses.to(torch.float64).sum() - val_token_count += n_scored - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -# Large tensors to keep as fp16 passthrough instead of int8 quantization. -# Tied embeddings serve dual duty (input + output head) and are disproportionately -# sensitive to quantization noise (~0.007 BPB gap). -FP16_PASSTHROUGH_PATTERNS = tuple( - p for p in os.environ.get("FP16_PASSTHROUGH_PATTERNS", "tok_emb").split(",") if p -) -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -# Full int6 quantization: all 2D block weights use [-31,31] range (63 levels) -# stored in int8 container. zstd-22 compresses the low-entropy data much better. -USE_INT6 = bool(int(os.environ.get("USE_INT6", "1"))) -USE_ZSTD = bool(int(os.environ.get("USE_ZSTD", "1"))) -ZSTD_LEVEL = int(os.environ.get("ZSTD_LEVEL", 22)) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor, use_int6: bool = False) -> tuple[Tensor, Tensor]: - t32 = t.float() - qmax = 31 if use_int6 else 127 - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Export format: - # - per-row int6/int8 for 2D float tensors (int6 if USE_INT6) - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Large tensors matching FP16_PASSTHROUGH_PATTERNS are kept as fp16 - # instead of int8 (e.g. tied embeddings that are quant-sensitive). - if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t, use_int6=USE_INT6) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -def _fake_quantize_int6(w: Tensor) -> Tensor: - """STE fake int6 quantization: forward uses quantized values, backward passes through.""" - # Use amax instead of quantile — compile-friendly and O(n) per row - clip_abs = w.abs().amax(dim=1).clamp_min(1.0 / 31.0) - scale = clip_abs / 31.0 - q = (w / scale[:, None]).round().clamp(-31, 31) - dequant = q * scale[:, None] - return w + (dequant - w).detach() # STE: forward=quantized, backward=identity - -# Global flag to enable/disable QAT (disabled during eval) -_QAT_ENABLED = False - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if _QAT_ENABLED and self.weight.ndim == 2: - w = _fake_quantize_int6(w) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, hidden: int): - super().__init__() - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_hidden: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_hidden: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_hidden, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - """Return per-token losses shaped (batch, seq_len) for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - per_token = F.cross_entropy(logits.float(), targets, reduction="none") - return per_token.view(input_ids.shape) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - mlp_hidden = args.mlp_hidden if args.mlp_hidden > 0 else args.mlp_mult * args.model_dim - log0(f"mlp_hidden:{mlp_hidden}") - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_hidden=mlp_hidden, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - global _QAT_ENABLED - _QAT_ENABLED = USE_INT6 # Enable QAT during training if using int6 - log0(f"qat_enabled:{_QAT_ENABLED}") - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - _QAT_ENABLED = False # Disable QAT for eval/serialization - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if USE_ZSTD: - cctx = zstd.ZstdCompressor(level=ZSTD_LEVEL) - quant_blob = cctx.compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - compress_label = f"zstd-{ZSTD_LEVEL}" if USE_ZSTD else "zlib" - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_label}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_label}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if USE_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") - else: - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval for final score (if stride > 0) - if args.eval_stride > 0: - eval_sl = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - eval_seq_len=eval_sl, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_seq_len:{eval_sl} " - f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed1337.log b/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed1337.log deleted file mode 100644 index a35f771669..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed1337.log +++ /dev/null @@ -1,1515 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -import zstandard as zstd -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride=0 means standard (non-overlapping) eval. - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 0)) # 0 = use train_seq_len - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3600)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 1536)) # 3x model_dim, enabled by int6+zstd compression - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.04)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding( - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int = 2048, - stride: int = 64, - batch_size: int = 4, -) -> tuple[float, float]: - """Sliding window eval: each token scored with max context.""" - N = val_tokens.numel() - # Distribute windows across ranks - window_starts: list[int] = [] - pos = 0 - while pos + eval_seq_len < N: - window_starts.append(pos) - if pos == 0: - pos = eval_seq_len - stride # after first full window - else: - pos += stride - - # Partition windows across ranks - per_rank = len(window_starts) // max(world_size, 1) - rank_start = rank * per_rank - rank_end = (rank + 1) * per_rank if rank < world_size - 1 else len(window_starts) - my_windows = window_starts[rank_start:rank_end] - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for batch_start in range(0, len(my_windows), batch_size): - batch_windows = my_windows[batch_start:batch_start + batch_size] - xs, ys = [], [] - for w in batch_windows: - xs.append(val_tokens[w:w + eval_seq_len]) - ys.append(val_tokens[w + 1:w + eval_seq_len + 1]) - x = torch.stack(xs).to(device=device, dtype=torch.int64) - y = torch.stack(ys).to(device=device, dtype=torch.int64) - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - per_token_loss = base_model.forward_per_token_loss(x, y) - - for i, w in enumerate(batch_windows): - if w == 0: - # First window: score all positions - score_start = 0 - else: - # Subsequent windows: score only last `stride` positions - score_start = eval_seq_len - stride - - losses = per_token_loss[i, score_start:] - tgt_ids = y[i, score_start:] - prev_ids = x[i, score_start:] - n_scored = losses.numel() - - val_loss_sum += losses.to(torch.float64).sum() - val_token_count += n_scored - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -# Large tensors to keep as fp16 passthrough instead of int8 quantization. -# Tied embeddings serve dual duty (input + output head) and are disproportionately -# sensitive to quantization noise (~0.007 BPB gap). -FP16_PASSTHROUGH_PATTERNS = tuple( - p for p in os.environ.get("FP16_PASSTHROUGH_PATTERNS", "tok_emb").split(",") if p -) -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -# Full int6 quantization: all 2D block weights use [-31,31] range (63 levels) -# stored in int8 container. zstd-22 compresses the low-entropy data much better. -USE_INT6 = bool(int(os.environ.get("USE_INT6", "1"))) -USE_ZSTD = bool(int(os.environ.get("USE_ZSTD", "1"))) -ZSTD_LEVEL = int(os.environ.get("ZSTD_LEVEL", 22)) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor, use_int6: bool = False) -> tuple[Tensor, Tensor]: - t32 = t.float() - qmax = 31 if use_int6 else 127 - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Export format: - # - per-row int6/int8 for 2D float tensors (int6 if USE_INT6) - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Large tensors matching FP16_PASSTHROUGH_PATTERNS are kept as fp16 - # instead of int8 (e.g. tied embeddings that are quant-sensitive). - if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t, use_int6=USE_INT6) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -def _fake_quantize_int6(w: Tensor) -> Tensor: - """STE fake int6 quantization: forward uses quantized values, backward passes through.""" - # Use amax instead of quantile — compile-friendly and O(n) per row - clip_abs = w.abs().amax(dim=1).clamp_min(1.0 / 31.0) - scale = clip_abs / 31.0 - q = (w / scale[:, None]).round().clamp(-31, 31) - dequant = q * scale[:, None] - return w + (dequant - w).detach() # STE: forward=quantized, backward=identity - -# Global flag to enable/disable QAT (disabled during eval) -_QAT_ENABLED = False - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if _QAT_ENABLED and self.weight.ndim == 2: - w = _fake_quantize_int6(w) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, hidden: int): - super().__init__() - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_hidden: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_hidden: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_hidden, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - """Return per-token losses shaped (batch, seq_len) for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - per_token = F.cross_entropy(logits.float(), targets, reduction="none") - return per_token.view(input_ids.shape) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - mlp_hidden = args.mlp_hidden if args.mlp_hidden > 0 else args.mlp_mult * args.model_dim - log0(f"mlp_hidden:{mlp_hidden}") - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_hidden=mlp_hidden, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - global _QAT_ENABLED - _QAT_ENABLED = USE_INT6 # Enable QAT during training if using int6 - log0(f"qat_enabled:{_QAT_ENABLED}") - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - _QAT_ENABLED = False # Disable QAT for eval/serialization - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if USE_ZSTD: - cctx = zstd.ZstdCompressor(level=ZSTD_LEVEL) - quant_blob = cctx.compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - compress_label = f"zstd-{ZSTD_LEVEL}" if USE_ZSTD else "zlib" - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_label}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_label}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if USE_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") - else: - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval for final score (if stride > 0) - if args.eval_stride > 0: - eval_sl = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - eval_seq_len=eval_sl, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_seq_len:{eval_sl} " - f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] -Running PyTorch 2.8.0+cu128 -Thu Mar 19 18:59:28 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 | -| N/A 30C P0 104W / 700W | 1542MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | -| N/A 31C P0 103W / 700W | 1650MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:75:00.0 Off | 0 | -| N/A 31C P0 103W / 700W | 2165MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | 0 | -| N/A 33C P0 104W / 700W | 1562MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:97:00.0 Off | 0 | -| N/A 33C P0 102W / 700W | 2279MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:A8:00.0 Off | 0 | -| N/A 29C P0 103W / 700W | 2181MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:B9:00.0 Off | 0 | -| N/A 32C P0 104W / 700W | 1576MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:CA:00.0 Off | 0 | -| N/A 31C P0 102W / 700W | 2207MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 164653 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 1 N/A N/A 164654 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 2 N/A N/A 135953 C /opt/venv/bin/python 616MiB | -| 2 N/A N/A 164655 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 3 N/A N/A 164656 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 4 N/A N/A 129526 C /opt/venv/bin/python 630MiB | -| 4 N/A N/A 164657 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 5 N/A N/A 122421 C /opt/venv/bin/python 616MiB | -| 5 N/A N/A 164658 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 6 N/A N/A 164659 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 7 N/A N/A 131633 C /opt/venv/bin/python 630MiB | -| 7 N/A N/A 164660 C .../envs/fi-bench/bin/python3.12 1516MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -mlp_hidden:1344 -model_params:22174288 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.04 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -qat_enabled:True -step:1/20000 train_loss:6.9348 train_time:28804ms step_avg:28803.77ms -step:2/20000 train_loss:15.2455 train_time:28848ms step_avg:14423.83ms -step:3/20000 train_loss:15.1987 train_time:28915ms step_avg:9638.40ms -step:4/20000 train_loss:15.1339 train_time:28983ms step_avg:7245.75ms -step:5/20000 train_loss:14.9721 train_time:29051ms step_avg:5810.22ms -step:6/20000 train_loss:14.8893 train_time:29119ms step_avg:4853.22ms -step:7/20000 train_loss:14.5652 train_time:29188ms step_avg:4169.66ms -step:8/20000 train_loss:14.2584 train_time:29256ms step_avg:3656.96ms -step:9/20000 train_loss:13.8421 train_time:29324ms step_avg:3258.23ms -step:10/20000 train_loss:13.3369 train_time:29393ms step_avg:2939.34ms -step:100/20000 train_loss:4.2032 train_time:35511ms step_avg:355.11ms -step:200/20000 train_loss:3.0297 train_time:42469ms step_avg:212.35ms -step:300/20000 train_loss:2.4981 train_time:49265ms step_avg:164.22ms -step:400/20000 train_loss:2.3384 train_time:56229ms step_avg:140.57ms -step:500/20000 train_loss:2.4563 train_time:63009ms step_avg:126.02ms -step:600/20000 train_loss:2.4993 train_time:69948ms step_avg:116.58ms -step:700/20000 train_loss:2.3930 train_time:76720ms step_avg:109.60ms -step:800/20000 train_loss:2.2397 train_time:83665ms step_avg:104.58ms -step:900/20000 train_loss:2.2854 train_time:90456ms step_avg:100.51ms -step:1000/20000 train_loss:2.3331 train_time:97405ms step_avg:97.40ms -step:1100/20000 train_loss:2.1988 train_time:104196ms step_avg:94.72ms -step:1200/20000 train_loss:2.3500 train_time:111145ms step_avg:92.62ms -step:1300/20000 train_loss:2.3244 train_time:117935ms step_avg:90.72ms -step:1400/20000 train_loss:2.3809 train_time:124887ms step_avg:89.20ms -step:1500/20000 train_loss:2.1906 train_time:131676ms step_avg:87.78ms -step:1600/20000 train_loss:2.0530 train_time:138629ms step_avg:86.64ms -step:1700/20000 train_loss:2.1228 train_time:145419ms step_avg:85.54ms -step:1800/20000 train_loss:2.1613 train_time:152371ms step_avg:84.65ms -step:1900/20000 train_loss:2.1430 train_time:159160ms step_avg:83.77ms -step:2000/20000 train_loss:2.1961 train_time:166110ms step_avg:83.06ms -step:2100/20000 train_loss:2.2126 train_time:173052ms step_avg:82.41ms -step:2200/20000 train_loss:2.0101 train_time:179842ms step_avg:81.75ms -step:2300/20000 train_loss:2.3139 train_time:186768ms step_avg:81.20ms -step:2400/20000 train_loss:2.1392 train_time:193557ms step_avg:80.65ms -step:2500/20000 train_loss:2.0649 train_time:200492ms step_avg:80.20ms -step:2600/20000 train_loss:2.3548 train_time:207281ms step_avg:79.72ms -step:2700/20000 train_loss:2.0866 train_time:214226ms step_avg:79.34ms -step:2800/20000 train_loss:2.1706 train_time:221015ms step_avg:78.93ms -step:2900/20000 train_loss:2.1141 train_time:227949ms step_avg:78.60ms -step:3000/20000 train_loss:2.1548 train_time:234738ms step_avg:78.25ms -step:3100/20000 train_loss:2.1279 train_time:241682ms step_avg:77.96ms -step:3200/20000 train_loss:2.1197 train_time:248469ms step_avg:77.65ms -step:3300/20000 train_loss:2.1664 train_time:255401ms step_avg:77.39ms -step:3400/20000 train_loss:2.0896 train_time:262191ms step_avg:77.12ms -step:3500/20000 train_loss:2.1775 train_time:269133ms step_avg:76.90ms -step:3600/20000 train_loss:2.0271 train_time:275921ms step_avg:76.64ms -step:3700/20000 train_loss:2.0593 train_time:282854ms step_avg:76.45ms -step:3800/20000 train_loss:2.1358 train_time:289643ms step_avg:76.22ms -step:3900/20000 train_loss:1.9078 train_time:296575ms step_avg:76.04ms -step:4000/20000 train_loss:2.1005 train_time:303363ms step_avg:75.84ms -step:4100/20000 train_loss:2.1093 train_time:310296ms step_avg:75.68ms -step:4200/20000 train_loss:2.0977 train_time:317227ms step_avg:75.53ms -step:4300/20000 train_loss:1.9403 train_time:324015ms step_avg:75.35ms -step:4400/20000 train_loss:2.0306 train_time:330959ms step_avg:75.22ms -step:4500/20000 train_loss:2.1817 train_time:337747ms step_avg:75.05ms -step:4600/20000 train_loss:1.8991 train_time:344692ms step_avg:74.93ms -step:4700/20000 train_loss:2.1887 train_time:351481ms step_avg:74.78ms -step:4800/20000 train_loss:2.1740 train_time:358412ms step_avg:74.67ms -step:4900/20000 train_loss:2.0819 train_time:365200ms step_avg:74.53ms -step:5000/20000 train_loss:1.9305 train_time:372133ms step_avg:74.43ms -step:5100/20000 train_loss:1.9331 train_time:378920ms step_avg:74.30ms -step:5200/20000 train_loss:2.0837 train_time:385861ms step_avg:74.20ms -step:5300/20000 train_loss:2.1152 train_time:392649ms step_avg:74.08ms -step:5400/20000 train_loss:2.0970 train_time:399582ms step_avg:74.00ms -step:5500/20000 train_loss:2.0548 train_time:406370ms step_avg:73.89ms -step:5600/20000 train_loss:2.0806 train_time:413304ms step_avg:73.80ms -step:5700/20000 train_loss:2.0759 train_time:420093ms step_avg:73.70ms -step:5800/20000 train_loss:2.0368 train_time:427035ms step_avg:73.63ms -step:5900/20000 train_loss:2.0007 train_time:433825ms step_avg:73.53ms -step:6000/20000 train_loss:2.1213 train_time:440770ms step_avg:73.46ms -step:6100/20000 train_loss:2.0200 train_time:447559ms step_avg:73.37ms -step:6200/20000 train_loss:1.9816 train_time:454508ms step_avg:73.31ms -step:6300/20000 train_loss:1.9321 train_time:461443ms step_avg:73.24ms -step:6400/20000 train_loss:2.0603 train_time:468232ms step_avg:73.16ms -step:6500/20000 train_loss:1.9761 train_time:475164ms step_avg:73.10ms -step:6600/20000 train_loss:2.0127 train_time:481952ms step_avg:73.02ms -step:6700/20000 train_loss:2.0560 train_time:488885ms step_avg:72.97ms -step:6800/20000 train_loss:2.0777 train_time:495676ms step_avg:72.89ms -step:6900/20000 train_loss:1.9868 train_time:502606ms step_avg:72.84ms -step:7000/20000 train_loss:2.1159 train_time:509397ms step_avg:72.77ms -step:7100/20000 train_loss:1.9522 train_time:516340ms step_avg:72.72ms -step:7200/20000 train_loss:2.0863 train_time:523129ms step_avg:72.66ms -step:7300/20000 train_loss:1.9689 train_time:530074ms step_avg:72.61ms -step:7400/20000 train_loss:2.0103 train_time:536864ms step_avg:72.55ms -step:7500/20000 train_loss:2.0058 train_time:543811ms step_avg:72.51ms -step:7600/20000 train_loss:1.8901 train_time:550601ms step_avg:72.45ms -step:7700/20000 train_loss:1.9680 train_time:557552ms step_avg:72.41ms -step:7800/20000 train_loss:2.0326 train_time:564342ms step_avg:72.35ms -step:7900/20000 train_loss:2.0107 train_time:571279ms step_avg:72.31ms -step:8000/20000 train_loss:1.9981 train_time:578070ms step_avg:72.26ms -step:8100/20000 train_loss:2.0352 train_time:584997ms step_avg:72.22ms -step:8200/20000 train_loss:2.0720 train_time:591786ms step_avg:72.17ms -step:8300/20000 train_loss:1.9949 train_time:598726ms step_avg:72.14ms -step:8319/20000 val_loss:1.9958 val_bpb:1.1821 train_time:600042ms step_avg:72.13ms -stopping_early: wallclock_cap train_time:600042ms step:8319/20000 -peak memory allocated: 12108 MiB reserved: 12414 MiB -Serialized model: 87686115 bytes -Code size: 56005 bytes -Total submission size: 87742120 bytes -Serialized model int6+zstd-22: 15502314 bytes (payload:22835776 raw_torch:22885563 payload_ratio:3.84x) -Total submission size int6+zstd-22: 15558319 bytes -final_int8_zlib_roundtrip val_loss:1.9959 val_bpb:1.1821 eval_time:10467ms -final_int8_zlib_roundtrip_exact val_loss:1.99587778 val_bpb:1.18207124 -final_sliding_window val_loss:1.9603 val_bpb:1.1610 stride:64 eval_seq_len:2048 eval_time:370426ms -final_sliding_window_exact val_loss:1.96032898 val_bpb:1.16101576 diff --git a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed3.log b/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed3.log deleted file mode 100644 index 04cc4aaad1..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed3.log +++ /dev/null @@ -1,1475 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -import zstandard as zstd -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride=0 means standard (non-overlapping) eval. - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 0)) # 0 = use train_seq_len - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3600)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 1536)) # 3x model_dim, enabled by int6+zstd compression - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.04)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding( - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int = 2048, - stride: int = 64, - batch_size: int = 4, -) -> tuple[float, float]: - """Sliding window eval: each token scored with max context.""" - N = val_tokens.numel() - # Distribute windows across ranks - window_starts: list[int] = [] - pos = 0 - while pos + eval_seq_len < N: - window_starts.append(pos) - if pos == 0: - pos = eval_seq_len - stride # after first full window - else: - pos += stride - - # Partition windows across ranks - per_rank = len(window_starts) // max(world_size, 1) - rank_start = rank * per_rank - rank_end = (rank + 1) * per_rank if rank < world_size - 1 else len(window_starts) - my_windows = window_starts[rank_start:rank_end] - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for batch_start in range(0, len(my_windows), batch_size): - batch_windows = my_windows[batch_start:batch_start + batch_size] - xs, ys = [], [] - for w in batch_windows: - xs.append(val_tokens[w:w + eval_seq_len]) - ys.append(val_tokens[w + 1:w + eval_seq_len + 1]) - x = torch.stack(xs).to(device=device, dtype=torch.int64) - y = torch.stack(ys).to(device=device, dtype=torch.int64) - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - per_token_loss = base_model.forward_per_token_loss(x, y) - - for i, w in enumerate(batch_windows): - if w == 0: - # First window: score all positions - score_start = 0 - else: - # Subsequent windows: score only last `stride` positions - score_start = eval_seq_len - stride - - losses = per_token_loss[i, score_start:] - tgt_ids = y[i, score_start:] - prev_ids = x[i, score_start:] - n_scored = losses.numel() - - val_loss_sum += losses.to(torch.float64).sum() - val_token_count += n_scored - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -# Large tensors to keep as fp16 passthrough instead of int8 quantization. -# Tied embeddings serve dual duty (input + output head) and are disproportionately -# sensitive to quantization noise (~0.007 BPB gap). -FP16_PASSTHROUGH_PATTERNS = tuple( - p for p in os.environ.get("FP16_PASSTHROUGH_PATTERNS", "tok_emb").split(",") if p -) -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -# Full int6 quantization: all 2D block weights use [-31,31] range (63 levels) -# stored in int8 container. zstd-22 compresses the low-entropy data much better. -USE_INT6 = bool(int(os.environ.get("USE_INT6", "1"))) -USE_ZSTD = bool(int(os.environ.get("USE_ZSTD", "1"))) -ZSTD_LEVEL = int(os.environ.get("ZSTD_LEVEL", 22)) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor, use_int6: bool = False) -> tuple[Tensor, Tensor]: - t32 = t.float() - qmax = 31 if use_int6 else 127 - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Export format: - # - per-row int6/int8 for 2D float tensors (int6 if USE_INT6) - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Large tensors matching FP16_PASSTHROUGH_PATTERNS are kept as fp16 - # instead of int8 (e.g. tied embeddings that are quant-sensitive). - if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t, use_int6=USE_INT6) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -def _fake_quantize_int6(w: Tensor) -> Tensor: - """STE fake int6 quantization: forward uses quantized values, backward passes through.""" - # Use amax instead of quantile — compile-friendly and O(n) per row - clip_abs = w.abs().amax(dim=1).clamp_min(1.0 / 31.0) - scale = clip_abs / 31.0 - q = (w / scale[:, None]).round().clamp(-31, 31) - dequant = q * scale[:, None] - return w + (dequant - w).detach() # STE: forward=quantized, backward=identity - -# Global flag to enable/disable QAT (disabled during eval) -_QAT_ENABLED = False - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if _QAT_ENABLED and self.weight.ndim == 2: - w = _fake_quantize_int6(w) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, hidden: int): - super().__init__() - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_hidden: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_hidden: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_hidden, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - """Return per-token losses shaped (batch, seq_len) for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - per_token = F.cross_entropy(logits.float(), targets, reduction="none") - return per_token.view(input_ids.shape) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - mlp_hidden = args.mlp_hidden if args.mlp_hidden > 0 else args.mlp_mult * args.model_dim - log0(f"mlp_hidden:{mlp_hidden}") - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_hidden=mlp_hidden, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - global _QAT_ENABLED - _QAT_ENABLED = USE_INT6 # Enable QAT during training if using int6 - log0(f"qat_enabled:{_QAT_ENABLED}") - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - _QAT_ENABLED = False # Disable QAT for eval/serialization - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if USE_ZSTD: - cctx = zstd.ZstdCompressor(level=ZSTD_LEVEL) - quant_blob = cctx.compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - compress_label = f"zstd-{ZSTD_LEVEL}" if USE_ZSTD else "zlib" - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_label}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_label}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if USE_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") - else: - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval for final score (if stride > 0) - if args.eval_stride > 0: - eval_sl = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - eval_seq_len=eval_sl, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_seq_len:{eval_sl} " - f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] -Running PyTorch 2.8.0+cu128 -Thu Mar 19 19:40:58 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 | -| N/A 29C P0 104W / 700W | 1542MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | -| N/A 30C P0 102W / 700W | 2175MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:75:00.0 Off | 0 | -| N/A 30C P0 104W / 700W | 2165MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | 0 | -| N/A 32C P0 104W / 700W | 1562MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:97:00.0 Off | 0 | -| N/A 32C P0 102W / 700W | 2279MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:A8:00.0 Off | 0 | -| N/A 29C P0 103W / 700W | 2181MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:B9:00.0 Off | 0 | -| N/A 31C P0 103W / 700W | 2101MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:CA:00.0 Off | 0 | -| N/A 30C P0 102W / 700W | 1572MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 195729 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 1 N/A N/A 192400 C /opt/venv/bin/python 522MiB | -| 1 N/A N/A 195730 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 2 N/A N/A 135953 C /opt/venv/bin/python 616MiB | -| 2 N/A N/A 195731 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 3 N/A N/A 195732 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 4 N/A N/A 129526 C /opt/venv/bin/python 630MiB | -| 4 N/A N/A 195733 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 5 N/A N/A 122421 C /opt/venv/bin/python 616MiB | -| 5 N/A N/A 195734 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 6 N/A N/A 184508 C /opt/venv/bin/python 522MiB | -| 6 N/A N/A 195735 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 7 N/A N/A 195736 C .../envs/fi-bench/bin/python3.12 1516MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -mlp_hidden:1344 -model_params:22174288 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.04 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:3 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -qat_enabled:True -step:1/20000 train_loss:6.9336 train_time:11500ms step_avg:11500.14ms -step:2/20000 train_loss:15.2595 train_time:11550ms step_avg:5775.05ms -step:3/20000 train_loss:15.1992 train_time:11617ms step_avg:3872.43ms -step:4/20000 train_loss:15.0556 train_time:11685ms step_avg:2921.22ms -step:5/20000 train_loss:14.6581 train_time:11753ms step_avg:2350.62ms -step:6/20000 train_loss:14.3336 train_time:11821ms step_avg:1970.19ms -step:7/20000 train_loss:13.5810 train_time:11889ms step_avg:1698.44ms -step:8/20000 train_loss:12.7839 train_time:11957ms step_avg:1494.62ms -step:9/20000 train_loss:11.8139 train_time:12025ms step_avg:1336.10ms -step:10/20000 train_loss:10.6918 train_time:12093ms step_avg:1209.28ms -step:200/20000 train_loss:2.8507 train_time:25150ms step_avg:125.75ms -step:400/20000 train_loss:2.2858 train_time:38880ms step_avg:97.20ms -step:600/20000 train_loss:2.4788 train_time:52613ms step_avg:87.69ms -step:800/20000 train_loss:2.2261 train_time:66322ms step_avg:82.90ms -step:1000/20000 train_loss:2.3220 train_time:80042ms step_avg:80.04ms -step:1200/20000 train_loss:2.3437 train_time:93782ms step_avg:78.15ms -step:1400/20000 train_loss:2.3756 train_time:107518ms step_avg:76.80ms -step:1600/20000 train_loss:2.0471 train_time:121257ms step_avg:75.79ms -step:1800/20000 train_loss:2.1595 train_time:134981ms step_avg:74.99ms -step:2000/20000 train_loss:2.1927 train_time:148710ms step_avg:74.35ms -step:2200/20000 train_loss:2.0077 train_time:162428ms step_avg:73.83ms -step:2400/20000 train_loss:2.1383 train_time:176141ms step_avg:73.39ms -step:2600/20000 train_loss:2.3492 train_time:189871ms step_avg:73.03ms -step:2800/20000 train_loss:2.1620 train_time:203609ms step_avg:72.72ms -step:3000/20000 train_loss:2.1581 train_time:217357ms step_avg:72.45ms -step:3200/20000 train_loss:2.1164 train_time:231090ms step_avg:72.22ms -step:3400/20000 train_loss:2.0825 train_time:244837ms step_avg:72.01ms -step:3600/20000 train_loss:2.0303 train_time:258566ms step_avg:71.82ms -step:3800/20000 train_loss:2.1319 train_time:272302ms step_avg:71.66ms -step:4000/20000 train_loss:2.0956 train_time:286063ms step_avg:71.52ms -step:4200/20000 train_loss:2.0911 train_time:299944ms step_avg:71.42ms -step:4400/20000 train_loss:2.0205 train_time:313684ms step_avg:71.29ms -step:4600/20000 train_loss:1.8936 train_time:327447ms step_avg:71.18ms -step:4800/20000 train_loss:2.1738 train_time:341199ms step_avg:71.08ms -step:5000/20000 train_loss:1.9280 train_time:354978ms step_avg:71.00ms -step:5200/20000 train_loss:2.0886 train_time:369244ms step_avg:71.01ms -step:5400/20000 train_loss:2.0986 train_time:382976ms step_avg:70.92ms -step:5600/20000 train_loss:2.0865 train_time:396706ms step_avg:70.84ms -step:5800/20000 train_loss:2.0414 train_time:410432ms step_avg:70.76ms -step:6000/20000 train_loss:2.1160 train_time:424156ms step_avg:70.69ms -step:6200/20000 train_loss:1.9865 train_time:437880ms step_avg:70.63ms -step:6400/20000 train_loss:2.0587 train_time:451607ms step_avg:70.56ms -step:6600/20000 train_loss:2.0143 train_time:465341ms step_avg:70.51ms -step:6800/20000 train_loss:2.0779 train_time:479087ms step_avg:70.45ms -step:7000/20000 train_loss:2.1154 train_time:492821ms step_avg:70.40ms -step:7200/20000 train_loss:2.0828 train_time:506533ms step_avg:70.35ms -step:7400/20000 train_loss:2.0102 train_time:520253ms step_avg:70.30ms -step:7600/20000 train_loss:1.8855 train_time:533980ms step_avg:70.26ms -step:7800/20000 train_loss:2.0290 train_time:547697ms step_avg:70.22ms -step:8000/20000 train_loss:1.9976 train_time:561412ms step_avg:70.18ms -step:8200/20000 train_loss:2.0731 train_time:575145ms step_avg:70.14ms -step:8400/20000 train_loss:2.0096 train_time:589014ms step_avg:70.12ms -step:8562/20000 val_loss:1.9924 val_bpb:1.1800 train_time:600036ms step_avg:70.08ms -stopping_early: wallclock_cap train_time:600036ms step:8562/20000 -peak memory allocated: 12108 MiB reserved: 12414 MiB -Serialized model: 87686115 bytes -Code size: 56005 bytes -Total submission size: 87742120 bytes -Serialized model int6+zstd-22: 15463854 bytes (payload:22835776 raw_torch:22885563 payload_ratio:3.84x) -Total submission size int6+zstd-22: 15519859 bytes -final_int8_zlib_roundtrip val_loss:1.9918 val_bpb:1.1797 eval_time:10291ms -final_int8_zlib_roundtrip_exact val_loss:1.99183056 val_bpb:1.17967425 -final_sliding_window val_loss:1.9563 val_bpb:1.1586 stride:64 eval_seq_len:2048 eval_time:356614ms -final_sliding_window_exact val_loss:1.95627871 val_bpb:1.15861696 diff --git a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed42.log b/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed42.log deleted file mode 100644 index 349ab1d639..0000000000 --- a/records/track_10min_16mb/2026-03-19_Seq2048_FP16Emb_TunedLR/train_seed42.log +++ /dev/null @@ -1,1475 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -import zstandard as zstd -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride=0 means standard (non-overlapping) eval. - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 0)) # 0 = use train_seq_len - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3600)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 1536)) # 3x model_dim, enabled by int6+zstd compression - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.04)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding( - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int = 2048, - stride: int = 64, - batch_size: int = 4, -) -> tuple[float, float]: - """Sliding window eval: each token scored with max context.""" - N = val_tokens.numel() - # Distribute windows across ranks - window_starts: list[int] = [] - pos = 0 - while pos + eval_seq_len < N: - window_starts.append(pos) - if pos == 0: - pos = eval_seq_len - stride # after first full window - else: - pos += stride - - # Partition windows across ranks - per_rank = len(window_starts) // max(world_size, 1) - rank_start = rank * per_rank - rank_end = (rank + 1) * per_rank if rank < world_size - 1 else len(window_starts) - my_windows = window_starts[rank_start:rank_end] - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for batch_start in range(0, len(my_windows), batch_size): - batch_windows = my_windows[batch_start:batch_start + batch_size] - xs, ys = [], [] - for w in batch_windows: - xs.append(val_tokens[w:w + eval_seq_len]) - ys.append(val_tokens[w + 1:w + eval_seq_len + 1]) - x = torch.stack(xs).to(device=device, dtype=torch.int64) - y = torch.stack(ys).to(device=device, dtype=torch.int64) - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - per_token_loss = base_model.forward_per_token_loss(x, y) - - for i, w in enumerate(batch_windows): - if w == 0: - # First window: score all positions - score_start = 0 - else: - # Subsequent windows: score only last `stride` positions - score_start = eval_seq_len - stride - - losses = per_token_loss[i, score_start:] - tgt_ids = y[i, score_start:] - prev_ids = x[i, score_start:] - n_scored = losses.numel() - - val_loss_sum += losses.to(torch.float64).sum() - val_token_count += n_scored - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -# Large tensors to keep as fp16 passthrough instead of int8 quantization. -# Tied embeddings serve dual duty (input + output head) and are disproportionately -# sensitive to quantization noise (~0.007 BPB gap). -FP16_PASSTHROUGH_PATTERNS = tuple( - p for p in os.environ.get("FP16_PASSTHROUGH_PATTERNS", "tok_emb").split(",") if p -) -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -# Full int6 quantization: all 2D block weights use [-31,31] range (63 levels) -# stored in int8 container. zstd-22 compresses the low-entropy data much better. -USE_INT6 = bool(int(os.environ.get("USE_INT6", "1"))) -USE_ZSTD = bool(int(os.environ.get("USE_ZSTD", "1"))) -ZSTD_LEVEL = int(os.environ.get("ZSTD_LEVEL", 22)) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor, use_int6: bool = False) -> tuple[Tensor, Tensor]: - t32 = t.float() - qmax = 31 if use_int6 else 127 - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Export format: - # - per-row int6/int8 for 2D float tensors (int6 if USE_INT6) - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Large tensors matching FP16_PASSTHROUGH_PATTERNS are kept as fp16 - # instead of int8 (e.g. tied embeddings that are quant-sensitive). - if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t, use_int6=USE_INT6) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -def _fake_quantize_int6(w: Tensor) -> Tensor: - """STE fake int6 quantization: forward uses quantized values, backward passes through.""" - # Use amax instead of quantile — compile-friendly and O(n) per row - clip_abs = w.abs().amax(dim=1).clamp_min(1.0 / 31.0) - scale = clip_abs / 31.0 - q = (w / scale[:, None]).round().clamp(-31, 31) - dequant = q * scale[:, None] - return w + (dequant - w).detach() # STE: forward=quantized, backward=identity - -# Global flag to enable/disable QAT (disabled during eval) -_QAT_ENABLED = False - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if _QAT_ENABLED and self.weight.ndim == 2: - w = _fake_quantize_int6(w) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, hidden: int): - super().__init__() - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_hidden: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_hidden: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_hidden, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - """Return per-token losses shaped (batch, seq_len) for sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - per_token = F.cross_entropy(logits.float(), targets, reduction="none") - return per_token.view(input_ids.shape) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - mlp_hidden = args.mlp_hidden if args.mlp_hidden > 0 else args.mlp_mult * args.model_dim - log0(f"mlp_hidden:{mlp_hidden}") - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_hidden=mlp_hidden, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - global _QAT_ENABLED - _QAT_ENABLED = USE_INT6 # Enable QAT during training if using int6 - log0(f"qat_enabled:{_QAT_ENABLED}") - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - _QAT_ENABLED = False # Disable QAT for eval/serialization - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if USE_ZSTD: - cctx = zstd.ZstdCompressor(level=ZSTD_LEVEL) - quant_blob = cctx.compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - compress_label = f"zstd-{ZSTD_LEVEL}" if USE_ZSTD else "zlib" - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_label}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_label}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if USE_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") - else: - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval for final score (if stride > 0) - if args.eval_stride > 0: - eval_sl = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - eval_seq_len=eval_sl, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_seq_len:{eval_sl} " - f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] -Running PyTorch 2.8.0+cu128 -Thu Mar 19 19:18:15 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 | -| N/A 31C P0 104W / 700W | 1542MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | -| N/A 33C P0 104W / 700W | 1650MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:75:00.0 Off | 0 | -| N/A 32C P0 104W / 700W | 2165MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | 0 | -| N/A 36C P0 106W / 700W | 1562MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:97:00.0 Off | 0 | -| N/A 35C P0 104W / 700W | 2279MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:A8:00.0 Off | 0 | -| N/A 30C P0 103W / 700W | 2181MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:B9:00.0 Off | 0 | -| N/A 34C P0 104W / 700W | 5175MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:CA:00.0 Off | 0 | -| N/A 32C P0 103W / 700W | 2207MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 184849 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 1 N/A N/A 184850 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 2 N/A N/A 135953 C /opt/venv/bin/python 616MiB | -| 2 N/A N/A 184851 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 3 N/A N/A 184852 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 4 N/A N/A 129526 C /opt/venv/bin/python 630MiB | -| 4 N/A N/A 184853 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 5 N/A N/A 122421 C /opt/venv/bin/python 616MiB | -| 5 N/A N/A 184854 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 6 N/A N/A 184508 C /opt/venv/bin/python 3594MiB | -| 6 N/A N/A 184855 C .../envs/fi-bench/bin/python3.12 1516MiB | -| 7 N/A N/A 131633 C /opt/venv/bin/python 630MiB | -| 7 N/A N/A 184856 C .../envs/fi-bench/bin/python3.12 1516MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -mlp_hidden:1344 -model_params:22174288 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.04 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -qat_enabled:True -step:1/20000 train_loss:6.9351 train_time:11712ms step_avg:11712.12ms -step:2/20000 train_loss:15.3475 train_time:11761ms step_avg:5880.53ms -step:3/20000 train_loss:15.2938 train_time:11828ms step_avg:3942.80ms -step:4/20000 train_loss:15.1012 train_time:11896ms step_avg:2974.08ms -step:5/20000 train_loss:14.7546 train_time:11964ms step_avg:2392.89ms -step:6/20000 train_loss:14.4415 train_time:12033ms step_avg:2005.44ms -step:7/20000 train_loss:13.6915 train_time:12101ms step_avg:1728.78ms -step:8/20000 train_loss:12.9190 train_time:12170ms step_avg:1521.21ms -step:9/20000 train_loss:11.9571 train_time:12238ms step_avg:1359.74ms -step:10/20000 train_loss:10.8447 train_time:12305ms step_avg:1230.55ms -step:200/20000 train_loss:2.8531 train_time:25366ms step_avg:126.83ms -step:400/20000 train_loss:2.2860 train_time:39093ms step_avg:97.73ms -step:600/20000 train_loss:2.4757 train_time:52799ms step_avg:88.00ms -step:800/20000 train_loss:2.2231 train_time:66504ms step_avg:83.13ms -step:1000/20000 train_loss:2.3201 train_time:80209ms step_avg:80.21ms -step:1200/20000 train_loss:2.3436 train_time:93936ms step_avg:78.28ms -step:1400/20000 train_loss:2.3729 train_time:107671ms step_avg:76.91ms -step:1600/20000 train_loss:2.0468 train_time:121413ms step_avg:75.88ms -step:1800/20000 train_loss:2.1553 train_time:135151ms step_avg:75.08ms -step:2000/20000 train_loss:2.2011 train_time:148882ms step_avg:74.44ms -step:2200/20000 train_loss:2.0109 train_time:162617ms step_avg:73.92ms -step:2400/20000 train_loss:2.1412 train_time:176341ms step_avg:73.48ms -step:2600/20000 train_loss:2.3548 train_time:190071ms step_avg:73.10ms -step:2800/20000 train_loss:2.1670 train_time:203792ms step_avg:72.78ms -step:3000/20000 train_loss:2.1575 train_time:217511ms step_avg:72.50ms -step:3200/20000 train_loss:2.1165 train_time:231225ms step_avg:72.26ms -step:3400/20000 train_loss:2.0834 train_time:244940ms step_avg:72.04ms -step:3600/20000 train_loss:2.0278 train_time:258655ms step_avg:71.85ms -step:3800/20000 train_loss:2.1311 train_time:272367ms step_avg:71.68ms -step:4000/20000 train_loss:2.1007 train_time:286085ms step_avg:71.52ms -step:4200/20000 train_loss:2.0916 train_time:299952ms step_avg:71.42ms -step:4400/20000 train_loss:2.0264 train_time:313671ms step_avg:71.29ms -step:4600/20000 train_loss:1.8936 train_time:327387ms step_avg:71.17ms -step:4800/20000 train_loss:2.1812 train_time:341095ms step_avg:71.06ms -step:5000/20000 train_loss:1.9342 train_time:354842ms step_avg:70.97ms -step:5200/20000 train_loss:2.0896 train_time:368553ms step_avg:70.88ms -step:5400/20000 train_loss:2.1011 train_time:382286ms step_avg:70.79ms -step:5600/20000 train_loss:2.0844 train_time:396017ms step_avg:70.72ms -step:5800/20000 train_loss:2.0419 train_time:409733ms step_avg:70.64ms -step:6000/20000 train_loss:2.1222 train_time:423454ms step_avg:70.58ms -step:6200/20000 train_loss:1.9862 train_time:438681ms step_avg:70.75ms -step:6400/20000 train_loss:2.0657 train_time:452391ms step_avg:70.69ms -step:6600/20000 train_loss:2.0174 train_time:466115ms step_avg:70.62ms -step:6800/20000 train_loss:2.0753 train_time:479822ms step_avg:70.56ms -step:7000/20000 train_loss:2.1198 train_time:493541ms step_avg:70.51ms -step:7200/20000 train_loss:2.0827 train_time:507258ms step_avg:70.45ms -step:7400/20000 train_loss:2.0139 train_time:520984ms step_avg:70.40ms -step:7600/20000 train_loss:1.8885 train_time:534712ms step_avg:70.36ms -step:7800/20000 train_loss:2.0334 train_time:548439ms step_avg:70.31ms -step:8000/20000 train_loss:1.9978 train_time:562155ms step_avg:70.27ms -step:8200/20000 train_loss:2.0720 train_time:575879ms step_avg:70.23ms -step:8400/20000 train_loss:2.0147 train_time:589737ms step_avg:70.21ms -step:8552/20000 val_loss:1.9938 val_bpb:1.1808 train_time:600076ms step_avg:70.17ms -stopping_early: wallclock_cap train_time:600076ms step:8552/20000 -peak memory allocated: 12108 MiB reserved: 12414 MiB -Serialized model: 87686115 bytes -Code size: 56005 bytes -Total submission size: 87742120 bytes -Serialized model int6+zstd-22: 15342994 bytes (payload:22835776 raw_torch:22885563 payload_ratio:3.84x) -Total submission size int6+zstd-22: 15398999 bytes -final_int8_zlib_roundtrip val_loss:1.9938 val_bpb:1.1808 eval_time:10277ms -final_int8_zlib_roundtrip_exact val_loss:1.99381198 val_bpb:1.18084776 -final_sliding_window val_loss:1.9583 val_bpb:1.1598 stride:64 eval_seq_len:2048 eval_time:355951ms -final_sliding_window_exact val_loss:1.95825352 val_bpb:1.15978656 diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/README.md b/records/track_10min_16mb/2026-03-19_SlidingWindowEval/README.md deleted file mode 100644 index e4c8ef6fd4..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/README.md +++ /dev/null @@ -1,78 +0,0 @@ -This record implements sliding window evaluation, showing that eval strategies alone can provide significant improvements. - -**Note on `train_gpt.py`:** The included script contains some unused experimental code paths (QAT, looped architectures) that are **all disabled by default** and were not active during this run. Only the sliding window evaluation code (`eval_val_sliding`, `forward_logits`, `EVAL_STRIDE`, `EVAL_BATCH_SEQS`) is used. The command below shows the exact invocation. - -## Key Idea: Sliding Window Evaluation - -The baseline evaluates by chopping the validation set into non-overlapping 1024-token chunks. The problem is that the first token in each chunk has zero context. On average, each token gets ~512 tokens of context. - -Sliding window evaluation uses overlapping windows with a configurable stride. With `EVAL_STRIDE=64` and `TRAIN_SEQ_LEN=1024`, each window advances by 64 tokens, but only the rightmost 64 tokens (which have 960+ tokens of context) are scored. Every token in the validation set is scored exactly once, but with near-maximum context. - -## Results - -| Metric | Naive Baseline | This Submission | -|---|---|---| -| Pre-quant val_bpb | 1.2172 | 1.2196 | -| **Post-quant val_bpb** | **1.2244** | **1.1925** | -| **Improvement** | — | **-0.0319** | -| Training steps | 13,780 | 13,450 | -| Eval time (8xH100) | ~16s | 70s | -| Artifact size | 15,863,489 bytes | 15,874,829 bytes | - -The pre-quant BPB is nearly identical (training is the same). The 0.032 improvement comes entirely from scoring tokens with richer context during evaluation. - -## Configuration - -Architecture and training are identical to the Naive Baseline: -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied output/input embeddings: `TIE_EMBEDDINGS=1` -- Tied embedding LR: `TIED_EMBED_LR=0.05` -- Batching: `TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=1024` - -Evaluation-specific parameters: -- `EVAL_STRIDE=64` (sliding window stride; baseline uses non-overlapping = stride 1024) -- `EVAL_BATCH_SEQS=1024` (number of windows per forward pass for GPU utilization) - -## Command - -```bash -RUN_ID=8xh100_slide64_v2 \ -DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -NUM_LOOPS=1 \ -LORA_RANK=0 \ -QAT=0 \ -EVAL_STRIDE=64 \ -EVAL_BATCH_SEQS=1024 \ -MAX_WALLCLOCK_SECONDS=600 \ -TRAIN_LOG_EVERY=200 \ -VAL_LOSS_EVERY=1000 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -The `NUM_LOOPS=1 LORA_RANK=0 QAT=0` flags explicitly disable all unused code paths (these are also the defaults). - -## Key Metrics (from `train.log`) - -- Timed training stopped at `13450/20000` steps due to the wallclock cap. -- Pre-quant eval at stop: `val_loss:2.0592`, `val_bpb:1.2196` -- Post-quant sliding window eval: `val_loss:2.0135`, `val_bpb:1.1925` -- Exact printed metric: `final_int8_zlib_roundtrip_exact val_bpb:1.19250007` -- Train time: `600028ms` (`step_avg:44.61ms`) -- Peak memory: `10119 MiB allocated`, `10294 MiB reserved` -- Eval time: `69881ms` (sliding window, stride=64, batch_seqs=1024) -- Serialized model int8+zlib: `15816489 bytes` -- Code size: `58340 bytes` -- Total submission size int8+zlib: `15874829 bytes` - -## Training Volume - -- Global batch: `524288` tokens/step -- Total train tokens seen: `7,055,769,600` - -## Included Files - -- `train_gpt.py` (code snapshot used for the run, includes `eval_val_sliding` function) -- `train.log` (exact remote training log) -- `submission.json` (leaderboard metadata) diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/submission.json b/records/track_10min_16mb/2026-03-19_SlidingWindowEval/submission.json deleted file mode 100644 index d25b325f97..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/submission.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "author": "Matthew Li", - "github_id": "mattqlf", - "name": "Sliding Window Eval (stride=64)", - "blurb": "Baseline 9x512 SP-1024 architecture with sliding window evaluation at stride=64. Each token is scored with 960+ tokens of context instead of the baseline's 0-1023. Training is identical to the naive baseline; the improvement comes entirely from the evaluation strategy. Post-quant int8+zlib roundtrip under the 16,000,000-byte cap.", - "date": "2026-03-19T04:48:00Z", - "val_loss": 2.01348383, - "val_bpb": 1.19250007, - "pre_quant_val_loss": 2.0592, - "pre_quant_val_bpb": 1.2196, - "step_stop": 13450, - "wallclock_seconds": 600.028, - "eval_time_seconds": 69.881, - "bytes_total": 15874829, - "bytes_model_int8_zlib": 15816489, - "bytes_code": 58340 -} \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train.log b/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train.log deleted file mode 100644 index 8bd9edc7d8..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train.log +++ /dev/null @@ -1,133 +0,0 @@ -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -qat:False -model_params:17059912 (unique_layers:9 loops:1 effective_depth:9 lora_rank:0 lora_params:0) -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9370 train_time:24ms step_avg:23.88ms -step:2/20000 train_loss:16.8366 train_time:62ms step_avg:31.18ms -step:3/20000 train_loss:8.7608 train_time:105ms step_avg:35.06ms -step:4/20000 train_loss:6.6385 train_time:148ms step_avg:37.01ms -step:5/20000 train_loss:6.6114 train_time:191ms step_avg:38.20ms -step:6/20000 train_loss:7.4220 train_time:234ms step_avg:39.05ms -step:7/20000 train_loss:6.3508 train_time:277ms step_avg:39.61ms -step:8/20000 train_loss:6.1582 train_time:320ms step_avg:40.05ms -step:9/20000 train_loss:6.0678 train_time:364ms step_avg:40.39ms -step:10/20000 train_loss:5.9747 train_time:407ms step_avg:40.66ms -step:200/20000 train_loss:2.8545 train_time:8724ms step_avg:43.62ms -step:400/20000 train_loss:2.3579 train_time:17484ms step_avg:43.71ms -step:600/20000 train_loss:2.5468 train_time:26272ms step_avg:43.79ms -step:800/20000 train_loss:2.2933 train_time:35060ms step_avg:43.83ms -step:1000/20000 train_loss:2.3741 train_time:43870ms step_avg:43.87ms -step:1000/20000 val_loss:2.3339 val_bpb:1.3823 train_time:43898ms step_avg:43.90ms -step:1200/20000 train_loss:2.3859 train_time:52691ms step_avg:43.91ms -step:1400/20000 train_loss:2.4313 train_time:61546ms step_avg:43.96ms -step:1600/20000 train_loss:2.0990 train_time:70406ms step_avg:44.00ms -step:1800/20000 train_loss:2.1989 train_time:79268ms step_avg:44.04ms -step:2000/20000 train_loss:2.2537 train_time:88144ms step_avg:44.07ms -step:2000/20000 val_loss:2.2349 val_bpb:1.3236 train_time:88173ms step_avg:44.09ms -step:2200/20000 train_loss:2.0705 train_time:97057ms step_avg:44.12ms -step:2400/20000 train_loss:2.2003 train_time:105955ms step_avg:44.15ms -step:2600/20000 train_loss:2.4100 train_time:114848ms step_avg:44.17ms -step:2800/20000 train_loss:2.2339 train_time:123759ms step_avg:44.20ms -step:3000/20000 train_loss:2.2271 train_time:132674ms step_avg:44.22ms -step:3000/20000 val_loss:2.1940 val_bpb:1.2994 train_time:132702ms step_avg:44.23ms -step:3200/20000 train_loss:2.1853 train_time:141596ms step_avg:44.25ms -step:3400/20000 train_loss:2.1579 train_time:150510ms step_avg:44.27ms -step:3600/20000 train_loss:2.1150 train_time:159433ms step_avg:44.29ms -step:3800/20000 train_loss:2.2207 train_time:168353ms step_avg:44.30ms -step:4000/20000 train_loss:2.1629 train_time:177281ms step_avg:44.32ms -step:4000/20000 val_loss:2.1691 val_bpb:1.2846 train_time:177309ms step_avg:44.33ms -step:4200/20000 train_loss:2.1755 train_time:186254ms step_avg:44.35ms -step:4400/20000 train_loss:2.1075 train_time:195164ms step_avg:44.36ms -step:4600/20000 train_loss:1.9721 train_time:204095ms step_avg:44.37ms -step:4800/20000 train_loss:2.2620 train_time:213026ms step_avg:44.38ms -step:5000/20000 train_loss:2.0261 train_time:221961ms step_avg:44.39ms -step:5000/20000 val_loss:2.1527 val_bpb:1.2749 train_time:221991ms step_avg:44.40ms -step:5200/20000 train_loss:2.1734 train_time:230894ms step_avg:44.40ms -step:5400/20000 train_loss:2.1832 train_time:239840ms step_avg:44.41ms -step:5600/20000 train_loss:2.1834 train_time:248772ms step_avg:44.42ms -step:5800/20000 train_loss:2.1438 train_time:257705ms step_avg:44.43ms -step:6000/20000 train_loss:2.2213 train_time:266645ms step_avg:44.44ms -step:6000/20000 val_loss:2.1428 val_bpb:1.2691 train_time:266673ms step_avg:44.45ms -step:6200/20000 train_loss:2.0903 train_time:275590ms step_avg:44.45ms -step:6400/20000 train_loss:2.1614 train_time:284523ms step_avg:44.46ms -step:6600/20000 train_loss:2.1233 train_time:293461ms step_avg:44.46ms -step:6800/20000 train_loss:2.1883 train_time:302396ms step_avg:44.47ms -step:7000/20000 train_loss:2.2269 train_time:311350ms step_avg:44.48ms -step:7000/20000 val_loss:2.1319 val_bpb:1.2626 train_time:311378ms step_avg:44.48ms -step:7200/20000 train_loss:2.1985 train_time:320283ms step_avg:44.48ms -step:7400/20000 train_loss:2.1159 train_time:329218ms step_avg:44.49ms -step:7600/20000 train_loss:2.0015 train_time:338182ms step_avg:44.50ms -step:7800/20000 train_loss:2.1457 train_time:347121ms step_avg:44.50ms -step:8000/20000 train_loss:2.1162 train_time:356081ms step_avg:44.51ms -step:8000/20000 val_loss:2.1223 val_bpb:1.2570 train_time:356110ms step_avg:44.51ms -step:8200/20000 train_loss:2.1840 train_time:365027ms step_avg:44.52ms -step:8400/20000 train_loss:2.1384 train_time:374085ms step_avg:44.53ms -step:8600/20000 train_loss:2.1382 train_time:383022ms step_avg:44.54ms -step:8800/20000 train_loss:2.1010 train_time:391971ms step_avg:44.54ms -step:9000/20000 train_loss:2.0244 train_time:400928ms step_avg:44.55ms -step:9000/20000 val_loss:2.1174 val_bpb:1.2540 train_time:400957ms step_avg:44.55ms -step:9200/20000 train_loss:2.0847 train_time:409874ms step_avg:44.55ms -step:9400/20000 train_loss:2.1341 train_time:418805ms step_avg:44.55ms -step:9600/20000 train_loss:2.1481 train_time:427753ms step_avg:44.56ms -step:9800/20000 train_loss:2.0727 train_time:436682ms step_avg:44.56ms -step:10000/20000 train_loss:2.1143 train_time:445623ms step_avg:44.56ms -step:10000/20000 val_loss:2.1124 val_bpb:1.2511 train_time:445652ms step_avg:44.57ms -step:10200/20000 train_loss:2.0665 train_time:454563ms step_avg:44.57ms -step:10400/20000 train_loss:2.0990 train_time:463504ms step_avg:44.57ms -step:10600/20000 train_loss:1.9760 train_time:472458ms step_avg:44.57ms -step:10800/20000 train_loss:2.1863 train_time:481398ms step_avg:44.57ms -step:11000/20000 train_loss:2.1152 train_time:490335ms step_avg:44.58ms -step:11000/20000 val_loss:2.1058 val_bpb:1.2472 train_time:490363ms step_avg:44.58ms -step:11200/20000 train_loss:2.0681 train_time:499305ms step_avg:44.58ms -step:11400/20000 train_loss:2.0572 train_time:508232ms step_avg:44.58ms -step:11600/20000 train_loss:2.0625 train_time:517178ms step_avg:44.58ms -step:11800/20000 train_loss:2.0980 train_time:526122ms step_avg:44.59ms -step:12000/20000 train_loss:2.0710 train_time:535066ms step_avg:44.59ms -step:12000/20000 val_loss:2.1003 val_bpb:1.2439 train_time:535094ms step_avg:44.59ms -step:12200/20000 train_loss:2.2155 train_time:544026ms step_avg:44.59ms -step:12400/20000 train_loss:1.8595 train_time:553021ms step_avg:44.60ms -step:12600/20000 train_loss:2.0846 train_time:561982ms step_avg:44.60ms -step:12800/20000 train_loss:2.0964 train_time:570913ms step_avg:44.60ms -step:13000/20000 train_loss:2.1690 train_time:579870ms step_avg:44.61ms -step:13000/20000 val_loss:2.0744 val_bpb:1.2286 train_time:579898ms step_avg:44.61ms -step:13200/20000 train_loss:2.1741 train_time:588820ms step_avg:44.61ms -step:13400/20000 train_loss:2.0456 train_time:597778ms step_avg:44.61ms -step:13450/20000 val_loss:2.0592 val_bpb:1.2196 train_time:600028ms step_avg:44.61ms -stopping_early: wallclock_cap train_time:600028ms step:13450/20000 -peak memory allocated: 10119 MiB reserved: 10294 MiB -Serialized model: 67224983 bytes -Code size: 58340 bytes -Total submission size: 67283323 bytes -Serialized model int8+zlib: 15816489 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15874829 bytes -final_eval_mode:sliding_window stride:64 batch_seqs:1024 -final_int8_zlib_roundtrip val_loss:2.0135 val_bpb:1.1925 eval_time:69881ms -final_int8_zlib_roundtrip_exact val_loss:2.01348383 val_bpb:1.19250007 diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py b/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py deleted file mode 100644 index 6a8fd84a81..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindowEval/train_gpt.py +++ /dev/null @@ -1,1366 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - num_loops = int(os.environ.get("NUM_LOOPS", 1)) - lora_rank = int(os.environ.get("LORA_RANK", 0)) - qat = bool(int(os.environ.get("QAT", "1"))) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - lora_lr = float(os.environ.get("LORA_LR", 0.01)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -def fake_quantize_int8_per_row(w: Tensor) -> Tensor: - """Simulate per-row int8 quantization with straight-through estimator. - - Forward: uses quantized-then-dequantized weights (same rounding as post-training). - Backward: gradients pass through as if no quantization happened (STE). - """ - scale = w.detach().abs().amax(dim=-1, keepdim=True).div_(127.0).clamp_(min=1.0 / 127.0) - w_deq = (w / scale).round().clamp_(-127, 127) * scale - return w + (w_deq - w).detach() - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - _qat: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight - if self._qat and self.training: - w = fake_quantize_int8_per_row(w) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w.to(x.dtype), bias) - - -class AttentionLoRA(nn.Module): - """Per-iteration LoRA adapters for attention Q, K, V, and output projections. - - Initialized so that the LoRA contribution is zero at the start of training - (B matrices are zeros). During training, the optimizer learns per-iteration - specialization while the base attention weights remain shared across loops. - """ - def __init__(self, dim: int, kv_dim: int, rank: int): - super().__init__() - self.q_A = nn.Parameter(torch.empty(dim, rank)) - self.q_B = nn.Parameter(torch.zeros(rank, dim)) - self.k_A = nn.Parameter(torch.empty(dim, rank)) - self.k_B = nn.Parameter(torch.zeros(rank, kv_dim)) - self.v_A = nn.Parameter(torch.empty(dim, rank)) - self.v_B = nn.Parameter(torch.zeros(rank, kv_dim)) - self.proj_A = nn.Parameter(torch.empty(dim, rank)) - self.proj_B = nn.Parameter(torch.zeros(rank, dim)) - self._init_lora() - - def _init_lora(self) -> None: - for name in ("q_A", "k_A", "v_A", "proj_A"): - nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5)) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor, lora: AttentionLoRA | None = None) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x) - k = self.c_k(x) - v = self.c_v(x) - if lora is not None: - # LoRA delta: (bsz, seqlen, dim) @ (dim, rank) @ (rank, out_dim) - # autocast handles fp32->bf16 cast of LoRA params automatically - q = q + (x @ lora.q_A) @ lora.q_B - k = k + (x @ lora.k_A) @ lora.k_B - v = v + (x @ lora.v_A) @ lora.v_B - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - out = self.proj(y) - if lora is not None: - out = out + (y @ lora.proj_A) @ lora.proj_B - return out - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor, lora: AttentionLoRA | None = None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x), lora=lora) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - num_loops: int = 1, - lora_rank: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.num_unique_layers = num_layers - self.num_loops = num_loops - effective_depth = num_layers * num_loops - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = effective_depth // 2 - self.num_decoder_layers = effective_depth - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - # Per-(loop, block) LoRA adapters for attention projections. - # Only created when num_loops > 1 and lora_rank > 0. - kv_dim = num_kv_heads * (model_dim // num_heads) - if lora_rank > 0 and num_loops > 1: - self.lora_adapters = nn.ModuleList( - [ - nn.ModuleList( - [AttentionLoRA(model_dim, kv_dim, lora_rank) for _ in range(num_layers)] - ) - for _ in range(num_loops) - ] - ) - else: - self.lora_adapters = None - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # Iterate through effective layers: each unique block is reused across loops. - # First half (encoder) stores skip connections; second half (decoder) pops them. - eff_idx = 0 - for loop_idx in range(self.num_loops): - for block_idx in range(self.num_unique_layers): - lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None - if eff_idx < self.num_encoder_layers: - x = self.blocks[block_idx](x, x0, lora=lora) - skips.append(x) - else: - dec_idx = eff_idx - self.num_encoder_layers - if dec_idx < self.num_skip_weights and skips: - x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[block_idx](x, x0, lora=lora) - eff_idx += 1 - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - eff_idx = 0 - for loop_idx in range(self.num_loops): - for block_idx in range(self.num_unique_layers): - lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None - if eff_idx < self.num_encoder_layers: - x = self.blocks[block_idx](x, x0, lora=lora) - skips.append(x) - else: - dec_idx = eff_idx - self.num_encoder_layers - if dec_idx < self.num_skip_weights and skips: - x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[block_idx](x, x0, lora=lora) - eff_idx += 1 - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context. - - Windows of train_seq_len advance by `stride`. Only the last `stride` tokens - per window contribute to the score (first window scores all). Windows are - batched and distributed across ranks. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - # Build windows; include final partial window if it has at least 1 token - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - # Distribute across ranks - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - # Progress (rank 0 only) - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - num_loops=args.num_loops, - lora_rank=args.lora_rank, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, (CastedLinear, AttentionLoRA)): - module.float() - if isinstance(module, CastedLinear) and args.qat: - module._qat = True - restore_low_dim_params_to_fp32(base_model) - log0(f"qat:{args.qat}") - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lora_adapters is not None: - lora_params = list(base_model.lora_adapters.parameters()) - optimizer_lora = torch.optim.Adam( - [{"params": lora_params, "lr": args.lora_lr, "base_lr": args.lora_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.append(optimizer_lora) - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - n_lora = sum(p.numel() for p in base_model.lora_adapters.parameters()) if base_model.lora_adapters is not None else 0 - effective_depth = args.num_layers * args.num_loops - log0(f"model_params:{n_params} (unique_layers:{args.num_layers} loops:{args.num_loops} effective_depth:{effective_depth} lora_rank:{args.lora_rank} lora_params:{n_lora})") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/README.md b/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/README.md deleted file mode 100644 index d873eb5a60..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/README.md +++ /dev/null @@ -1,28 +0,0 @@ -# Sliding Window + FP16 Embed + 10L + Muon WD + Overtone Init - -**Mean val_bpb: 1.1748** (3 seeds, p<0.001) - -## Key Techniques - -1. **Sliding window evaluation** (stride=64, seq_len=1024): Every token scored with 960+ context instead of 0-1023 average. Compiled `forward_logits` method for efficient batch inference. - -2. **FP16 tied embedding export**: Keep `tok_emb.weight` in fp16 — int8 errors compound through both input and output paths. - -3. **10 transformer layers** (up from 9): Muon weight decay compresses enough to fit the extra layer. - -4. **Decoupled weight decay for Muon optimizer** (0.02): Improves generalization and quantization robustness. - -5. **Overtone spectral embedding init**: SVD power-law spectrum shaping (`S_k ~ k^{-0.5}`). - -6. **Phase-transition residual mixing**: Sigmoid-scheduled `resid_mix` initialization. - -## Results - -| Seed | val_loss | val_bpb | Steps | ms/step | -|------|----------|---------|-------|---------| -| 1337 | 1.9849 | 1.1756 | 10424 | 57.55 | -| 42 | 1.9827 | 1.1742 | 10710 | 56.06 | -| 7 | 1.9830 | 1.1744 | 10498 | 57.18 | -| **Mean** | **1.9835** | **1.1748** | | | - -Artifact: ~14.7 MB | Eval time: ~162s (sliding window) diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/submission.json b/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/submission.json deleted file mode 100644 index 968ccf57e1..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/submission.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "track": "10min_16mb", - "date": "2026-03-19", - "name": "Sliding Window + FP16 Embed + 10L + Muon WD + Overtone Init", - "author": "notapplica", - "seed_results": { - "1337": {"val_loss": 1.98492632, "val_bpb": 1.17558517, "steps": 10424, "ms_per_step": 57.55}, - "42": {"val_loss": 1.98265459, "val_bpb": 1.17423973, "steps": 10710, "ms_per_step": 56.06}, - "7": {"val_loss": 1.98298356, "val_bpb": 1.17443456, "steps": 10498, "ms_per_step": 57.18} - }, - "mean_val_loss": 1.98352149, - "mean_val_bpb": 1.17475315, - "p_value": 0.0001, - "artifact_bytes": 15374243, - "code_bytes": 50651 -} diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_gpt.py b/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_gpt.py deleted file mode 100644 index a631fc7c8b..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_gpt.py +++ /dev/null @@ -1,1298 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2500)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 0)) # 0 = same as train_seq_len - eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) # 0 = non-overlapping; >0 = sliding window - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.10)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - seq_len_override: int = 0, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - seq_len = seq_len_override if seq_len_override > 0 else args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Keep tied embedding in fp16 (int8 quantization hurts both input and output paths) - if "tok_emb" in name: - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - # Supports NTK-aware dynamic scaling for eval at longer sequence lengths. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - # NTK-aware dynamic scaling when eval seq_len > train seq_len - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - adjusted_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (adjusted_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - train_seq_len: int = 1024, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - train_seq_len: int = 1024, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len=train_seq_len) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - train_seq_len: int = 1024, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - train_seq_len=train_seq_len, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - # Overtone init: shape embedding spectrum to power-law decay (like guitar harmonics) - with torch.no_grad(): - U, S, V = torch.linalg.svd(self.tok_emb.weight.data, full_matrices=False) - target_S = S[0] * (1.0 / torch.arange(1, S.shape[0] + 1, dtype=S.dtype)) ** 0.5 - self.tok_emb.weight.data = (U * target_S[None, :]) @ V - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - # Phase-transition resid_mix: early layers trust x0 more, late layers trust residual - num_layers = len(self.blocks) - for i, block in enumerate(self.blocks): - with torch.no_grad(): - phase = torch.sigmoid(torch.tensor(3.0 * (i / max(num_layers - 1, 1) - 0.5))) - block.resid_mix.data[0] = phase * torch.ones(block.resid_mix.shape[1]) - block.resid_mix.data[1] = (1 - phase) * torch.ones(block.resid_mix.shape[1]) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x_norm = self.final_norm(x) - x_flat = x_norm.reshape(-1, x_norm.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - loss = F.cross_entropy(logits.float(), targets, reduction="mean") - return loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Forward pass returning logits [batch, seq_len, vocab]. For sliding window eval.""" - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - logits = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits / self.logit_softcap) - - -def eval_val_sliding( - logits_fn, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - seq_len: int, - stride: int, - eval_batch_seqs: int = 256, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with near-full context.""" - total = val_tokens.numel() - 1 - - # Build windows: (start_pos, score_offset) - windows: list[tuple[int, int]] = [] - p = 0 - while p + seq_len <= total: - s = 0 if p == 0 else (seq_len - stride) - windows.append((p, s)) - p += stride - - # Distribute across ranks - n = len(windows) - per_rank = (n + world_size - 1) // world_size - my_start = rank * per_rank - my_end = min(my_start + per_rank, n) - my_windows = windows[my_start:my_end] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - tok_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - with torch.inference_mode(): - for i in range(0, len(my_windows), eval_batch_seqs): - batch = my_windows[i : i + eval_batch_seqs] - bs = len(batch) - - # Pad to eval_batch_seqs to avoid recompilation - x_list = [val_tokens[w : w + seq_len] for w, _ in batch] - y_list = [val_tokens[w + 1 : w + seq_len + 1] for w, _ in batch] - pad = eval_batch_seqs - bs - if pad > 0: - x_list.extend([x_list[-1]] * pad) - y_list.extend([y_list[-1]] * pad) - - x = torch.stack(x_list).to(device=device, dtype=torch.int64) - y = torch.stack(y_list).to(device=device, dtype=torch.int64) - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = logits_fn(x) - - for b in range(bs): - s = batch[b][1] - scored_logits = logits[b, s:] - scored_targets = y[b, s:] - - loss = F.cross_entropy(scored_logits.float(), scored_targets, reduction="sum") - loss_sum += loss.to(torch.float64) - ns = scored_targets.numel() - tok_count += ns - - prev = x[b, s : s + ns] - tgt = scored_targets - tb = base_bytes_lut[tgt].to(torch.int16) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.int16) - byte_count += tb.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(tok_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / tok_count).item() - bpb = val_loss / math.log(2.0) * (tok_count.item() / byte_count.item()) - return val_loss, bpb - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - train_seq_len=args.train_seq_len, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.AdamW( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=0.01, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=0.01, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - # Decoupled weight decay for Muon-optimized matrix params (not built into Muon) - with torch.no_grad(): - for p in matrix_params: - p.mul_(1.0 - 0.02 * optimizer_muon.param_groups[0]["lr"]) - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state, then produce the compressed int8+zlib artifact. - export_state = dict(base_model.state_dict()) - - if master_process: - torch.save(export_state, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(export_state) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=False) - - # Determine eval configuration - eval_sl = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - eval_stride = args.eval_stride - if eval_sl != args.train_seq_len: - val_tokens_eval = load_validation_tokens(args.val_files, eval_sl) - log0(f"eval_seq_len:{eval_sl} (NTK-scaled RoPE, {val_tokens_eval.numel()} tokens)") - else: - val_tokens_eval = val_tokens - - torch.cuda.synchronize() - t_qeval = time.perf_counter() - - if eval_stride > 0: - # Sliding window eval: compile forward_logits for efficiency - log0(f"Compiling forward_logits for sliding window eval (stride={eval_stride}, seq_len={eval_sl})...") - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False) - # Warmup compilation - eval_batch_seqs = 256 - warmup_x = torch.zeros(eval_batch_seqs, eval_sl, dtype=torch.int64, device=device) - base_model.eval() - with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - _ = compiled_logits(warmup_x) - log0("Compilation done, starting sliding window eval...") - - q_val_loss, q_val_bpb = eval_val_sliding( - compiled_logits, rank, world_size, device, - val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_sl, eval_stride, eval_batch_seqs=eval_batch_seqs, - ) - base_model.train() - else: - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - seq_len_override=eval_sl if eval_sl != args.train_seq_len else 0, - ) - - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed1337.log b/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed1337.log deleted file mode 100644 index 7aec40e480..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed1337.log +++ /dev/null @@ -1,111 +0,0 @@ -Uploading train_gpt.py (55483 bytes)... -[modal] Launching torchrun on 8xH100... -[modal] Script size: 55483 bytes -logs/autoresearch_seed1337.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:18897488 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.1 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/20000 train_loss:6.9338 train_time:67ms step_avg:67.02ms -step:2/20000 train_loss:23.6125 train_time:102ms step_avg:50.94ms -step:3/20000 train_loss:9.4294 train_time:160ms step_avg:53.35ms -step:4/20000 train_loss:6.3084 train_time:216ms step_avg:54.01ms -step:5/20000 train_loss:6.2337 train_time:274ms step_avg:54.76ms -step:6/20000 train_loss:7.4515 train_time:329ms step_avg:54.88ms -step:7/20000 train_loss:6.4829 train_time:386ms step_avg:55.12ms -step:8/20000 train_loss:6.4743 train_time:442ms step_avg:55.29ms -step:9/20000 train_loss:6.4677 train_time:503ms step_avg:55.89ms -step:10/20000 train_loss:6.3990 train_time:564ms step_avg:56.40ms -step:200/20000 train_loss:2.9388 train_time:11527ms step_avg:57.63ms -step:400/20000 train_loss:2.3863 train_time:23039ms step_avg:57.60ms -step:600/20000 train_loss:2.5645 train_time:34540ms step_avg:57.57ms -step:800/20000 train_loss:2.3133 train_time:46054ms step_avg:57.57ms -step:1000/20000 train_loss:2.3843 train_time:57601ms step_avg:57.60ms -step:1200/20000 train_loss:2.4015 train_time:69091ms step_avg:57.58ms -step:1400/20000 train_loss:2.4502 train_time:80586ms step_avg:57.56ms -step:1600/20000 train_loss:2.1239 train_time:92177ms step_avg:57.61ms -step:1800/20000 train_loss:2.2221 train_time:103711ms step_avg:57.62ms -step:2000/20000 train_loss:2.2741 train_time:115209ms step_avg:57.60ms -step:2200/20000 train_loss:2.1040 train_time:126685ms step_avg:57.58ms -step:2400/20000 train_loss:2.2265 train_time:138204ms step_avg:57.58ms -step:2600/20000 train_loss:2.4465 train_time:149723ms step_avg:57.59ms -step:2800/20000 train_loss:2.2744 train_time:161225ms step_avg:57.58ms -step:3000/20000 train_loss:2.2649 train_time:172784ms step_avg:57.59ms -step:3200/20000 train_loss:2.2328 train_time:184333ms step_avg:57.60ms -step:3400/20000 train_loss:2.2007 train_time:195830ms step_avg:57.60ms -step:3600/20000 train_loss:2.1673 train_time:207350ms step_avg:57.60ms -step:3800/20000 train_loss:2.2638 train_time:218858ms step_avg:57.59ms -step:4000/20000 train_loss:2.2112 train_time:230370ms step_avg:57.59ms -step:4200/20000 train_loss:2.2167 train_time:242053ms step_avg:57.63ms -step:4400/20000 train_loss:2.1627 train_time:253568ms step_avg:57.63ms -step:4600/20000 train_loss:2.0236 train_time:265062ms step_avg:57.62ms -step:4800/20000 train_loss:2.3148 train_time:276549ms step_avg:57.61ms -step:5000/20000 train_loss:2.0844 train_time:288028ms step_avg:57.61ms -step:5200/20000 train_loss:2.2262 train_time:299533ms step_avg:57.60ms -step:5400/20000 train_loss:2.2461 train_time:311015ms step_avg:57.60ms -step:5600/20000 train_loss:2.2471 train_time:322523ms step_avg:57.59ms -step:5800/20000 train_loss:2.2060 train_time:334020ms step_avg:57.59ms -step:6000/20000 train_loss:2.2712 train_time:345531ms step_avg:57.59ms -step:6200/20000 train_loss:2.1508 train_time:357039ms step_avg:57.59ms -step:6400/20000 train_loss:2.2255 train_time:368529ms step_avg:57.58ms -step:6600/20000 train_loss:2.1887 train_time:380044ms step_avg:57.58ms -step:6800/20000 train_loss:2.2524 train_time:391539ms step_avg:57.58ms -step:7000/20000 train_loss:2.2864 train_time:403034ms step_avg:57.58ms -step:7200/20000 train_loss:2.2706 train_time:414528ms step_avg:57.57ms -step:7400/20000 train_loss:2.1733 train_time:426028ms step_avg:57.57ms -step:7600/20000 train_loss:2.0536 train_time:437506ms step_avg:57.57ms -step:7800/20000 train_loss:2.2108 train_time:449001ms step_avg:57.56ms -step:8000/20000 train_loss:2.1741 train_time:460485ms step_avg:57.56ms -step:8200/20000 train_loss:2.2311 train_time:471969ms step_avg:57.56ms -step:8400/20000 train_loss:2.1656 train_time:483615ms step_avg:57.57ms -step:8600/20000 train_loss:2.1694 train_time:495096ms step_avg:57.57ms -step:8800/20000 train_loss:2.1230 train_time:506588ms step_avg:57.57ms -step:9000/20000 train_loss:2.0405 train_time:518062ms step_avg:57.56ms -step:9200/20000 train_loss:2.0900 train_time:529570ms step_avg:57.56ms -step:9400/20000 train_loss:2.1280 train_time:541073ms step_avg:57.56ms -step:9600/20000 train_loss:2.1301 train_time:552602ms step_avg:57.56ms -step:9800/20000 train_loss:2.0449 train_time:564107ms step_avg:57.56ms -step:10000/20000 train_loss:2.0685 train_time:575592ms step_avg:57.56ms -step:10200/20000 train_loss:2.0080 train_time:587087ms step_avg:57.56ms -step:10400/20000 train_loss:2.0274 train_time:598561ms step_avg:57.55ms -step:10424/20000 val_loss:2.0418 val_bpb:1.2092 train_time:599952ms step_avg:57.55ms -stopping_early: wallclock_cap train_time:599952ms step:10424/20000 -peak memory allocated: 11329 MiB reserved: 11514 MiB -Serialized model: 74573987 bytes -Code size: 55483 bytes -Total submission size: 74629470 bytes -Serialized model int8+zlib: 15318760 bytes (payload:19552576 raw_torch:19602363 payload_ratio:3.81x) -Total submission size int8+zlib: 15374243 bytes -Compiling forward_logits for sliding window eval (stride=64, seq_len=1024)... -Compilation done, starting sliding window eval... -final_int8_zlib_roundtrip val_loss:1.9849 val_bpb:1.1756 eval_time:161632ms -final_int8_zlib_roundtrip_exact val_loss:1.98492632 val_bpb:1.17558517 -[modal] Exit code: 0 - -[local] Done. Log saved to run_seed1337.log diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed42.log b/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed42.log deleted file mode 100644 index de12d73ed0..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed42.log +++ /dev/null @@ -1,112 +0,0 @@ -Uploading train_gpt.py (55483 bytes)... -[modal] Launching torchrun on 8xH100... -[modal] Script size: 55483 bytes -logs/autoresearch_seed42.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:18897488 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.1 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/20000 train_loss:6.9323 train_time:62ms step_avg:61.65ms -step:2/20000 train_loss:23.6476 train_time:93ms step_avg:46.55ms -step:3/20000 train_loss:9.6205 train_time:155ms step_avg:51.71ms -step:4/20000 train_loss:6.3404 train_time:209ms step_avg:52.28ms -step:5/20000 train_loss:6.2318 train_time:266ms step_avg:53.20ms -step:6/20000 train_loss:7.4480 train_time:321ms step_avg:53.52ms -step:7/20000 train_loss:6.4688 train_time:377ms step_avg:53.82ms -step:8/20000 train_loss:6.4513 train_time:437ms step_avg:54.64ms -step:9/20000 train_loss:6.4109 train_time:492ms step_avg:54.72ms -step:10/20000 train_loss:6.3441 train_time:547ms step_avg:54.73ms -step:200/20000 train_loss:2.9265 train_time:11165ms step_avg:55.82ms -step:400/20000 train_loss:2.3613 train_time:22355ms step_avg:55.89ms -step:600/20000 train_loss:2.5526 train_time:33575ms step_avg:55.96ms -step:800/20000 train_loss:2.3010 train_time:44782ms step_avg:55.98ms -step:1000/20000 train_loss:2.3803 train_time:55974ms step_avg:55.97ms -step:1200/20000 train_loss:2.3977 train_time:67386ms step_avg:56.16ms -step:1400/20000 train_loss:2.4409 train_time:78537ms step_avg:56.10ms -step:1600/20000 train_loss:2.1152 train_time:89803ms step_avg:56.13ms -step:1800/20000 train_loss:2.2157 train_time:101111ms step_avg:56.17ms -step:2000/20000 train_loss:2.2767 train_time:112316ms step_avg:56.16ms -step:2200/20000 train_loss:2.1024 train_time:123486ms step_avg:56.13ms -step:2400/20000 train_loss:2.2208 train_time:134684ms step_avg:56.12ms -step:2600/20000 train_loss:2.4388 train_time:145876ms step_avg:56.11ms -step:2800/20000 train_loss:2.2708 train_time:157065ms step_avg:56.09ms -step:3000/20000 train_loss:2.2585 train_time:168280ms step_avg:56.09ms -step:3200/20000 train_loss:2.2229 train_time:179520ms step_avg:56.10ms -step:3400/20000 train_loss:2.1937 train_time:190683ms step_avg:56.08ms -step:3600/20000 train_loss:2.1559 train_time:201869ms step_avg:56.07ms -step:3800/20000 train_loss:2.2586 train_time:213099ms step_avg:56.08ms -step:4000/20000 train_loss:2.2032 train_time:224296ms step_avg:56.07ms -step:4200/20000 train_loss:2.2127 train_time:235671ms step_avg:56.11ms -step:4400/20000 train_loss:2.1618 train_time:246847ms step_avg:56.10ms -step:4600/20000 train_loss:2.0127 train_time:258038ms step_avg:56.10ms -step:4800/20000 train_loss:2.3044 train_time:269211ms step_avg:56.09ms -step:5000/20000 train_loss:2.0753 train_time:280381ms step_avg:56.08ms -step:5200/20000 train_loss:2.2253 train_time:291546ms step_avg:56.07ms -step:5400/20000 train_loss:2.2388 train_time:302775ms step_avg:56.07ms -step:5600/20000 train_loss:2.2436 train_time:313948ms step_avg:56.06ms -step:5800/20000 train_loss:2.2000 train_time:325136ms step_avg:56.06ms -step:6000/20000 train_loss:2.2707 train_time:336332ms step_avg:56.06ms -step:6200/20000 train_loss:2.1477 train_time:347529ms step_avg:56.05ms -step:6400/20000 train_loss:2.2215 train_time:358754ms step_avg:56.06ms -step:6600/20000 train_loss:2.1809 train_time:369968ms step_avg:56.06ms -step:6800/20000 train_loss:2.2446 train_time:381154ms step_avg:56.05ms -step:7000/20000 train_loss:2.2792 train_time:392335ms step_avg:56.05ms -step:7200/20000 train_loss:2.2634 train_time:403511ms step_avg:56.04ms -step:7400/20000 train_loss:2.1739 train_time:414699ms step_avg:56.04ms -step:7600/20000 train_loss:2.0521 train_time:425838ms step_avg:56.03ms -step:7800/20000 train_loss:2.2032 train_time:437058ms step_avg:56.03ms -step:8000/20000 train_loss:2.1726 train_time:448259ms step_avg:56.03ms -step:8200/20000 train_loss:2.2406 train_time:459444ms step_avg:56.03ms -step:8400/20000 train_loss:2.1804 train_time:470727ms step_avg:56.04ms -step:8600/20000 train_loss:2.1752 train_time:481924ms step_avg:56.04ms -step:8800/20000 train_loss:2.1352 train_time:493074ms step_avg:56.03ms -step:9000/20000 train_loss:2.0513 train_time:504249ms step_avg:56.03ms -step:9200/20000 train_loss:2.1035 train_time:515413ms step_avg:56.02ms -step:9400/20000 train_loss:2.1378 train_time:526585ms step_avg:56.02ms -step:9600/20000 train_loss:2.1488 train_time:537753ms step_avg:56.02ms -step:9800/20000 train_loss:2.0549 train_time:548947ms step_avg:56.01ms -step:10000/20000 train_loss:2.0850 train_time:560135ms step_avg:56.01ms -step:10200/20000 train_loss:2.0299 train_time:571310ms step_avg:56.01ms -step:10400/20000 train_loss:2.0442 train_time:582450ms step_avg:56.00ms -step:10600/20000 train_loss:1.9142 train_time:594182ms step_avg:56.05ms -step:10695/20000 val_loss:2.0400 val_bpb:1.2082 train_time:600069ms step_avg:56.11ms -stopping_early: wallclock_cap train_time:600069ms step:10695/20000 -peak memory allocated: 11329 MiB reserved: 11514 MiB -Serialized model: 74573987 bytes -Code size: 55483 bytes -Total submission size: 74629470 bytes -Serialized model int8+zlib: 15294888 bytes (payload:19552576 raw_torch:19602363 payload_ratio:3.81x) -Total submission size int8+zlib: 15350371 bytes -Compiling forward_logits for sliding window eval (stride=64, seq_len=1024)... -Compilation done, starting sliding window eval... -final_int8_zlib_roundtrip val_loss:1.9827 val_bpb:1.1742 eval_time:159776ms -final_int8_zlib_roundtrip_exact val_loss:1.98265459 val_bpb:1.17423973 -[modal] Exit code: 0 - -[local] Done. Log saved to run_seed42.log diff --git a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed7.log b/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed7.log deleted file mode 100644 index 5075123074..0000000000 --- a/records/track_10min_16mb/2026-03-19_SlidingWindow_FP16Emb_10L_MuonWD_OvertoneInit/train_seed7.log +++ /dev/null @@ -1,111 +0,0 @@ -Uploading train_gpt.py (55483 bytes)... -[modal] Launching torchrun on 8xH100... -[modal] Script size: 55483 bytes -logs/autoresearch_seed7.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:18897488 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.1 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:7 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/20000 train_loss:6.9356 train_time:66ms step_avg:65.76ms -step:2/20000 train_loss:24.6999 train_time:104ms step_avg:51.85ms -step:3/20000 train_loss:10.1890 train_time:160ms step_avg:53.34ms -step:4/20000 train_loss:6.3509 train_time:220ms step_avg:54.89ms -step:5/20000 train_loss:6.1613 train_time:276ms step_avg:55.13ms -step:6/20000 train_loss:7.3943 train_time:335ms step_avg:55.86ms -step:7/20000 train_loss:6.4498 train_time:393ms step_avg:56.11ms -step:8/20000 train_loss:6.4479 train_time:450ms step_avg:56.25ms -step:9/20000 train_loss:6.4421 train_time:507ms step_avg:56.36ms -step:10/20000 train_loss:6.3911 train_time:565ms step_avg:56.53ms -step:200/20000 train_loss:2.9848 train_time:11388ms step_avg:56.94ms -step:400/20000 train_loss:2.3865 train_time:22780ms step_avg:56.95ms -step:600/20000 train_loss:2.5612 train_time:34174ms step_avg:56.96ms -step:800/20000 train_loss:2.3107 train_time:45574ms step_avg:56.97ms -step:1000/20000 train_loss:2.3844 train_time:56982ms step_avg:56.98ms -step:1200/20000 train_loss:2.3985 train_time:68369ms step_avg:56.97ms -step:1400/20000 train_loss:2.4504 train_time:79938ms step_avg:57.10ms -step:1600/20000 train_loss:2.1221 train_time:91344ms step_avg:57.09ms -step:1800/20000 train_loss:2.2191 train_time:102745ms step_avg:57.08ms -step:2000/20000 train_loss:2.2744 train_time:114167ms step_avg:57.08ms -step:2200/20000 train_loss:2.1048 train_time:125614ms step_avg:57.10ms -step:2400/20000 train_loss:2.2290 train_time:137051ms step_avg:57.10ms -step:2600/20000 train_loss:2.4426 train_time:148501ms step_avg:57.12ms -step:2800/20000 train_loss:2.2674 train_time:159925ms step_avg:57.12ms -step:3000/20000 train_loss:2.2556 train_time:171387ms step_avg:57.13ms -step:3200/20000 train_loss:2.2267 train_time:182854ms step_avg:57.14ms -step:3400/20000 train_loss:2.1966 train_time:194278ms step_avg:57.14ms -step:3600/20000 train_loss:2.1604 train_time:205735ms step_avg:57.15ms -step:3800/20000 train_loss:2.2604 train_time:217152ms step_avg:57.15ms -step:4000/20000 train_loss:2.2059 train_time:228561ms step_avg:57.14ms -step:4200/20000 train_loss:2.2209 train_time:240156ms step_avg:57.18ms -step:4400/20000 train_loss:2.1617 train_time:251590ms step_avg:57.18ms -step:4600/20000 train_loss:2.0158 train_time:263068ms step_avg:57.19ms -step:4800/20000 train_loss:2.3160 train_time:274498ms step_avg:57.19ms -step:5000/20000 train_loss:2.0824 train_time:285934ms step_avg:57.19ms -step:5200/20000 train_loss:2.2244 train_time:297389ms step_avg:57.19ms -step:5400/20000 train_loss:2.2446 train_time:308787ms step_avg:57.18ms -step:5600/20000 train_loss:2.2422 train_time:320213ms step_avg:57.18ms -step:5800/20000 train_loss:2.2005 train_time:331642ms step_avg:57.18ms -step:6000/20000 train_loss:2.2750 train_time:343140ms step_avg:57.19ms -step:6200/20000 train_loss:2.1433 train_time:354535ms step_avg:57.18ms -step:6400/20000 train_loss:2.2240 train_time:365957ms step_avg:57.18ms -step:6600/20000 train_loss:2.1820 train_time:377379ms step_avg:57.18ms -step:6800/20000 train_loss:2.2432 train_time:388815ms step_avg:57.18ms -step:7000/20000 train_loss:2.2803 train_time:400338ms step_avg:57.19ms -step:7200/20000 train_loss:2.2588 train_time:411733ms step_avg:57.19ms -step:7400/20000 train_loss:2.1762 train_time:423167ms step_avg:57.18ms -step:7600/20000 train_loss:2.0536 train_time:434610ms step_avg:57.19ms -step:7800/20000 train_loss:2.2062 train_time:446028ms step_avg:57.18ms -step:8000/20000 train_loss:2.1741 train_time:457471ms step_avg:57.18ms -step:8200/20000 train_loss:2.2289 train_time:468895ms step_avg:57.18ms -step:8400/20000 train_loss:2.1774 train_time:480489ms step_avg:57.20ms -step:8600/20000 train_loss:2.1666 train_time:491904ms step_avg:57.20ms -step:8800/20000 train_loss:2.1271 train_time:503351ms step_avg:57.20ms -step:9000/20000 train_loss:2.0412 train_time:514781ms step_avg:57.20ms -step:9200/20000 train_loss:2.0934 train_time:526175ms step_avg:57.19ms -step:9400/20000 train_loss:2.1279 train_time:537582ms step_avg:57.19ms -step:9600/20000 train_loss:2.1309 train_time:548991ms step_avg:57.19ms -step:9800/20000 train_loss:2.0441 train_time:560484ms step_avg:57.19ms -step:10000/20000 train_loss:2.0728 train_time:571896ms step_avg:57.19ms -step:10200/20000 train_loss:2.0129 train_time:583355ms step_avg:57.19ms -step:10400/20000 train_loss:2.0294 train_time:594780ms step_avg:57.19ms -step:10491/20000 val_loss:2.0395 val_bpb:1.2079 train_time:600073ms step_avg:57.20ms -stopping_early: wallclock_cap train_time:600073ms step:10491/20000 -peak memory allocated: 11329 MiB reserved: 11514 MiB -Serialized model: 74573987 bytes -Code size: 55483 bytes -Total submission size: 74629470 bytes -Serialized model int8+zlib: 15341602 bytes (payload:19552576 raw_torch:19602363 payload_ratio:3.81x) -Total submission size int8+zlib: 15397085 bytes -Compiling forward_logits for sliding window eval (stride=64, seq_len=1024)... -Compilation done, starting sliding window eval... -final_int8_zlib_roundtrip val_loss:1.9830 val_bpb:1.1744 eval_time:134729ms -final_int8_zlib_roundtrip_exact val_loss:1.98298356 val_bpb:1.17443456 -[modal] Exit code: 0 - -[local] Done. Log saved to run_seed7.log diff --git a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/README.md b/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/README.md deleted file mode 100644 index b6ea35074d..0000000000 --- a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/README.md +++ /dev/null @@ -1,76 +0,0 @@ -This record submission is called `Training Opt Seq4096 v1`. - -Configuration: -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied output/input embeddings: `TIE_EMBEDDINGS=1` -- Sequence length: `TRAIN_SEQ_LEN=4096` -- Batching: `TRAIN_BATCH_TOKENS=393216` (3/4 batch) -- Learning rates: `TIED_EMBED_LR=0.030 MATRIX_LR=0.020 SCALAR_LR=0.020` -- Muon optimizer: `MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_STEPS=1500 MUON_MOMENTUM_WARMUP_START=0.92` -- Schedule: `WARMDOWN_ITERS=3000` - -Command: -```bash -RUN_ID=training_opt_seq4096_v1 \ -DATA_PATH=./data/datasets/fineweb10B_sp1024 \ -TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -MAX_WALLCLOCK_SECONDS=600 \ -torchrun --standalone --nproc_per_node=8 \ - records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_gpt.py -``` - -Key metrics (from the standalone record run): -- Timed training stopped at `8394/20000` steps due to the wallclock cap. -- Pre-quant eval at stop: `val_loss:2.0227`, `val_bpb:1.1980` -- Post-quant roundtrip eval: `val_loss:2.0286`, `val_bpb:1.2014` -- Exact printed metric: `final_int8_zlib_roundtrip_exact val_bpb:1.20143417` -- Train time: `599921ms` (`step_avg:71.47ms`) -- Peak memory: `7748 MiB allocated`, `8070 MiB reserved` -- Serialized model int8+zlib: `15820684 bytes` -- Code size for this standalone record script: `47759 bytes` -- Total submission size int8+zlib: `15868326 bytes` - -Approach: -This submission combines two independent improvements over the naive baseline: - -1. **Longer training context (seq_len=4096):** Each training sequence sees 4x more context than the 1024-token baseline, giving the autoregressive model much better signal per token. This costs ~71ms/step (vs ~43ms at seq_len=1024), but the quality improvement far outweighs the fewer total steps. - -2. **Aggressive Muon optimizer tuning:** - - **Higher momentum (0.99 vs 0.95):** Provides stronger gradient smoothing, leading to better convergence. - - **Lower learning rates (0.020 vs 0.04):** Dramatically reduces int8 quantization loss (0.0034 BPB quant penalty vs 0.007+ at default LR) while maintaining similar pre-quant quality. - - **3/4 batch (393K vs 524K tokens):** More optimizer updates per wallclock second. - - **Extended momentum warmup (1500 steps from 0.92):** Prevents early instability with the higher momentum. - - **Longer warmdown (3000 steps):** Proportionally longer LR decay for the ~8400-step run. - -The net effect is a **0.023 BPB improvement** over the naive baseline (1.2014 vs 1.2244), and a **0.015 BPB improvement** over the previous best entry (Long Context Seq2048 v2 at 1.2162). - -Additional full-run reproducibility logs included in this folder: -- `train.log`: canonical standalone run, `SEED=1337`, `val_bpb=1.20143417` -- `train_seed1338.log`: full rerun, `SEED=1338`, `val_bpb=1.19945102` -- `train_seed1339.log`: full rerun, `SEED=1339`, `val_bpb=1.20319508` - -Record-track significance note: -- The current SOTA is `Long Context Seq2048 v2` at `1.21613611`. -- The challenge requires beating `1.21113611` (SOTA - 0.005) at p < 0.01. -- All three included full runs clear that threshold: - - `SEED=1337`: `1.20143417` - - `SEED=1338`: `1.19945102` - - `SEED=1339`: `1.20319508` -- Sample mean across the three runs: `1.20136009` -- Sample standard deviation: `0.00187` -- One-sided one-sample t-test against `1.21113611`: `t=9.06` with `df=2`, which gives `p=0.006` - -Hardware: 8x NVIDIA H100 80GB HBM3 (SXM, NVLink NV18 all-to-all), PyTorch 2.8.0+cu128. - -Why this folder is standalone: -- `train_gpt.py` compiles from inside this record folder and was used for the canonical run whose output is saved as `train.log`. -- No extra Python source files are required for the training path. -- The only inputs expected at runtime are the cached dataset and tokenizer paths described in the main repo README. - -Included files: -- `train_gpt.py` (standalone winning recipe with defaults baked in) -- `README.md` (this file) -- `submission.json` (leaderboard metadata) -- `train.log` (canonical full log from the standalone record script) -- `train_seed1338.log`, `train_seed1339.log` (extra full reruns for reproducibility) diff --git a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/submission.json b/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/submission.json deleted file mode 100644 index 71d9851168..0000000000 --- a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "Spokane Way", - "github_id": "spokane-way", - "name": "Training Opt Seq4096 v1", - "blurb": "SP-1024 9x512 KV4 run at TRAIN_SEQ_LEN=4096 with aggressively tuned Muon optimizer: momentum 0.99, lower LR (0.020/0.020/0.030), 3/4 batch (393K tokens), warmdown 3000 steps, and extended momentum warmup (1500 steps from 0.92). Combines long-context training with training optimization to beat the naive baseline by 0.023 BPB.", - "date": "2026-03-19T04:28:00Z", - "val_loss": 2.02857127, - "val_bpb": 1.20143417, - "bytes_total": 15868326, - "bytes_code": 47759 -} diff --git a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train.log b/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train.log deleted file mode 100644 index 947f7a096e..0000000000 --- a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train.log +++ /dev/null @@ -1,1246 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 20000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 800)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] -Running PyTorch 2.8.0+cu128 -Thu Mar 19 04:28:27 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 25C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 25C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 27C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 24C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 24C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 26C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 25C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 23C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 187171 C /usr/local/bin/python 1510MiB | -| 1 N/A N/A 187172 C /usr/local/bin/python 1510MiB | -| 2 N/A N/A 187173 C /usr/local/bin/python 1510MiB | -| 3 N/A N/A 187174 C /usr/local/bin/python 1510MiB | -| 4 N/A N/A 187175 C /usr/local/bin/python 1510MiB | -| 5 N/A N/A 187176 C /usr/local/bin/python 1510MiB | -| 6 N/A N/A 187177 C /usr/local/bin/python 1510MiB | -| 7 N/A N/A 187178 C /usr/local/bin/python 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:393216 train_seq_len:4096 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9376 train_time:34ms step_avg:33.93ms -step:2/20000 train_loss:12.1429 train_time:102ms step_avg:50.88ms -step:3/20000 train_loss:7.4535 train_time:177ms step_avg:59.05ms -step:4/20000 train_loss:6.3258 train_time:239ms step_avg:59.81ms -step:5/20000 train_loss:6.9181 train_time:303ms step_avg:60.56ms -step:6/20000 train_loss:6.9134 train_time:386ms step_avg:64.37ms -step:7/20000 train_loss:6.7000 train_time:450ms step_avg:64.32ms -step:8/20000 train_loss:6.6610 train_time:515ms step_avg:64.35ms -step:9/20000 train_loss:6.3090 train_time:606ms step_avg:67.28ms -step:10/20000 train_loss:6.2067 train_time:671ms step_avg:67.11ms -step:2000/20000 train_loss:2.3125 train_time:128919ms step_avg:64.46ms -step:2000/20000 val_loss:2.2111 val_bpb:1.3095 train_time:128959ms step_avg:64.48ms -step:4000/20000 train_loss:1.9264 train_time:277104ms step_avg:69.28ms -step:4000/20000 val_loss:2.1202 val_bpb:1.2557 train_time:277140ms step_avg:69.28ms -step:6000/20000 train_loss:1.9978 train_time:428170ms step_avg:71.36ms -step:6000/20000 val_loss:2.0783 val_bpb:1.2309 train_time:428206ms step_avg:71.37ms -step:8000/20000 train_loss:2.0986 train_time:570391ms step_avg:71.30ms -step:8000/20000 val_loss:2.0286 val_bpb:1.2015 train_time:570423ms step_avg:71.30ms -step:8394/20000 val_loss:2.0227 val_bpb:1.1980 train_time:599921ms step_avg:71.47ms -stopping_early: wallclock_cap train_time:599921ms step:8394/20000 -peak memory allocated: 7748 MiB reserved: 8070 MiB -Serialized model: 67224983 bytes -Code size: 47642 bytes -Total submission size: 67272625 bytes -Serialized model int8+zlib: 15820684 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15868326 bytes -final_int8_zlib_roundtrip val_loss:2.0286 val_bpb:1.2014 eval_time:2107ms -final_int8_zlib_roundtrip_exact val_loss:2.02857127 val_bpb:1.20143417 diff --git a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_gpt.py b/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_gpt.py deleted file mode 100644 index 8b85823c2a..0000000000 --- a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_gpt.py +++ /dev/null @@ -1,1127 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Training Opt Seq4096 v1 run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 4096, tied embeddings -# - 393,216 train tokens per step (3/4 batch) for 20,000 iterations with a ~10 minute cap -# - tuned Muon optimizer: momentum 0.99, lower LR (0.020/0.020/0.030), warmdown 3000 - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", "training_opt_seq4096_v1") - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_seed1338.log b/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_seed1338.log deleted file mode 100644 index 9f42fe3f0f..0000000000 --- a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_seed1338.log +++ /dev/null @@ -1,1293 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Training Opt Seq4096 v1 run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 4096, tied embeddings -# - 393,216 train tokens per step (3/4 batch) for 20,000 iterations with a ~10 minute cap -# - tuned Muon optimizer: momentum 0.99, lower LR (0.020/0.020/0.030), warmdown 3000 - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", "training_opt_seq4096_v1") - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] -Running PyTorch 2.8.0+cu128 -Thu Mar 19 04:52:27 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 24C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 23C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 25C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 23C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 22C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 24C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 24C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 22C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 201417 C /usr/local/bin/python 1510MiB | -| 1 N/A N/A 201418 C /usr/local/bin/python 1510MiB | -| 2 N/A N/A 201419 C /usr/local/bin/python 1510MiB | -| 3 N/A N/A 201420 C /usr/local/bin/python 1510MiB | -| 4 N/A N/A 201421 C /usr/local/bin/python 1510MiB | -| 5 N/A N/A 201422 C /usr/local/bin/python 1510MiB | -| 6 N/A N/A 201423 C /usr/local/bin/python 1510MiB | -| 7 N/A N/A 201424 C /usr/local/bin/python 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:393216 train_seq_len:4096 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1338 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9373 val_bpb:4.1086 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9374 train_time:37ms step_avg:36.55ms -step:2/20000 train_loss:12.1524 train_time:93ms step_avg:46.39ms -step:3/20000 train_loss:7.4424 train_time:174ms step_avg:57.84ms -step:4/20000 train_loss:6.3677 train_time:241ms step_avg:60.18ms -step:5/20000 train_loss:6.8789 train_time:304ms step_avg:60.75ms -step:6/20000 train_loss:6.9780 train_time:380ms step_avg:63.36ms -step:7/20000 train_loss:6.7945 train_time:446ms step_avg:63.64ms -step:8/20000 train_loss:6.6421 train_time:514ms step_avg:64.20ms -step:9/20000 train_loss:6.2938 train_time:588ms step_avg:65.31ms -step:10/20000 train_loss:6.1579 train_time:657ms step_avg:65.71ms -step:200/20000 train_loss:2.7689 train_time:10508ms step_avg:52.54ms -step:400/20000 train_loss:2.4187 train_time:24147ms step_avg:60.37ms -step:600/20000 train_loss:2.2876 train_time:37544ms step_avg:62.57ms -step:800/20000 train_loss:2.3653 train_time:50654ms step_avg:63.32ms -step:1000/20000 train_loss:2.3382 train_time:61075ms step_avg:61.08ms -step:1000/20000 val_loss:2.3005 val_bpb:1.3625 train_time:61107ms step_avg:61.11ms -step:1200/20000 train_loss:2.3484 train_time:74640ms step_avg:62.20ms -step:1400/20000 train_loss:2.3044 train_time:87762ms step_avg:62.69ms -step:1600/20000 train_loss:2.2646 train_time:101043ms step_avg:63.15ms -step:1800/20000 train_loss:2.0415 train_time:114134ms step_avg:63.41ms -step:2000/20000 train_loss:2.3040 train_time:124556ms step_avg:62.28ms -step:2000/20000 val_loss:2.2097 val_bpb:1.3087 train_time:124590ms step_avg:62.29ms -step:2200/20000 train_loss:2.0588 train_time:138202ms step_avg:62.82ms -step:2400/20000 train_loss:2.2112 train_time:151415ms step_avg:63.09ms -step:2600/20000 train_loss:2.3260 train_time:164733ms step_avg:63.36ms -step:2800/20000 train_loss:2.3926 train_time:178719ms step_avg:63.83ms -step:3000/20000 train_loss:2.1690 train_time:189109ms step_avg:63.04ms -step:3000/20000 val_loss:2.1535 val_bpb:1.2754 train_time:189143ms step_avg:63.05ms -step:3200/20000 train_loss:2.1964 train_time:202418ms step_avg:63.26ms -step:3400/20000 train_loss:2.2008 train_time:217452ms step_avg:63.96ms -step:3600/20000 train_loss:2.0782 train_time:231303ms step_avg:64.25ms -step:3800/20000 train_loss:2.0957 train_time:241700ms step_avg:63.61ms -step:4000/20000 train_loss:1.9226 train_time:255880ms step_avg:63.97ms -step:4000/20000 val_loss:2.1207 val_bpb:1.2560 train_time:255915ms step_avg:63.98ms -step:4200/20000 train_loss:2.0065 train_time:269315ms step_avg:64.12ms -step:4400/20000 train_loss:2.1154 train_time:283214ms step_avg:64.37ms -step:4600/20000 train_loss:2.0427 train_time:297518ms step_avg:64.68ms -step:4800/20000 train_loss:2.0627 train_time:307934ms step_avg:64.15ms -step:5000/20000 train_loss:2.1377 train_time:323094ms step_avg:64.62ms -step:5000/20000 val_loss:2.1016 val_bpb:1.2447 train_time:323128ms step_avg:64.63ms -step:5200/20000 train_loss:2.2018 train_time:337037ms step_avg:64.81ms -step:5400/20000 train_loss:1.8994 train_time:352259ms step_avg:65.23ms -step:5600/20000 train_loss:2.2430 train_time:367523ms step_avg:65.63ms -step:5800/20000 train_loss:2.1658 train_time:377940ms step_avg:65.16ms -step:6000/20000 train_loss:2.0026 train_time:392924ms step_avg:65.49ms -step:6000/20000 val_loss:2.0875 val_bpb:1.2363 train_time:392954ms step_avg:65.49ms -step:6200/20000 train_loss:2.1532 train_time:406879ms step_avg:65.63ms -step:6400/20000 train_loss:2.0957 train_time:420885ms step_avg:65.76ms -step:6600/20000 train_loss:2.0872 train_time:431290ms step_avg:65.35ms -step:6800/20000 train_loss:2.1602 train_time:445363ms step_avg:65.49ms -step:7000/20000 train_loss:2.1407 train_time:459048ms step_avg:65.58ms -step:7000/20000 val_loss:2.0648 val_bpb:1.2229 train_time:459081ms step_avg:65.58ms -step:7200/20000 train_loss:2.0602 train_time:473721ms step_avg:65.79ms -step:7400/20000 train_loss:2.0651 train_time:488975ms step_avg:66.08ms -step:7600/20000 train_loss:2.0228 train_time:499329ms step_avg:65.70ms -step:7800/20000 train_loss:2.0542 train_time:514168ms step_avg:65.92ms -step:8000/20000 train_loss:2.1130 train_time:528649ms step_avg:66.08ms -step:8000/20000 val_loss:2.0394 val_bpb:1.2079 train_time:528684ms step_avg:66.09ms -step:8200/20000 train_loss:2.1092 train_time:542310ms step_avg:66.14ms -step:8400/20000 train_loss:1.9927 train_time:556504ms step_avg:66.25ms -step:8600/20000 train_loss:2.0483 train_time:566945ms step_avg:65.92ms -step:8800/20000 train_loss:1.9529 train_time:580744ms step_avg:65.99ms -step:9000/20000 train_loss:2.0696 train_time:595074ms step_avg:66.12ms -step:9000/20000 val_loss:2.0198 val_bpb:1.1962 train_time:595109ms step_avg:66.12ms -step:9092/20000 val_loss:2.0193 val_bpb:1.1959 train_time:599903ms step_avg:65.98ms -stopping_early: wallclock_cap train_time:599903ms step:9092/20000 -peak memory allocated: 7748 MiB reserved: 8066 MiB -Serialized model: 67224983 bytes -Code size: 47759 bytes -Total submission size: 67272742 bytes -Serialized model int8+zlib: 15835395 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15883154 bytes -final_int8_zlib_roundtrip val_loss:2.0252 val_bpb:1.1995 eval_time:2235ms -final_int8_zlib_roundtrip_exact val_loss:2.02522281 val_bpb:1.19945102 diff --git a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_seed1339.log b/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_seed1339.log deleted file mode 100644 index 7223868bb4..0000000000 --- a/records/track_10min_16mb/2026-03-19_TrainingOptSeq4096/train_seed1339.log +++ /dev/null @@ -1,1289 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Training Opt Seq4096 v1 run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 4096, tied embeddings -# - 393,216 train tokens per step (3/4 batch) for 20,000 iterations with a ~10 minute cap -# - tuned Muon optimizer: momentum 0.99, lower LR (0.020/0.020/0.030), warmdown 3000 - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", "training_opt_seq4096_v1") - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 393_216)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 4096)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] -Running PyTorch 2.8.0+cu128 -Thu Mar 19 05:03:59 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 28C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 30C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 32C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 27C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 26C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 31C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 30C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 26C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 202262 C /usr/local/bin/python 1510MiB | -| 1 N/A N/A 202263 C /usr/local/bin/python 1510MiB | -| 2 N/A N/A 202264 C /usr/local/bin/python 1510MiB | -| 3 N/A N/A 202265 C /usr/local/bin/python 1510MiB | -| 4 N/A N/A 202266 C /usr/local/bin/python 1510MiB | -| 5 N/A N/A 202267 C /usr/local/bin/python 1510MiB | -| 6 N/A N/A 202268 C /usr/local/bin/python 1510MiB | -| 7 N/A N/A 202269 C /usr/local/bin/python 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:393216 train_seq_len:4096 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1339 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9372 val_bpb:4.1086 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9377 train_time:40ms step_avg:40.37ms -step:2/20000 train_loss:12.1131 train_time:102ms step_avg:50.86ms -step:3/20000 train_loss:7.3629 train_time:175ms step_avg:58.28ms -step:4/20000 train_loss:6.3140 train_time:245ms step_avg:61.16ms -step:5/20000 train_loss:6.8601 train_time:316ms step_avg:63.13ms -step:6/20000 train_loss:6.9431 train_time:378ms step_avg:62.98ms -step:7/20000 train_loss:6.7660 train_time:457ms step_avg:65.26ms -step:8/20000 train_loss:6.4793 train_time:527ms step_avg:65.87ms -step:9/20000 train_loss:6.3069 train_time:591ms step_avg:65.66ms -step:10/20000 train_loss:6.2345 train_time:654ms step_avg:65.37ms -step:200/20000 train_loss:2.7717 train_time:10524ms step_avg:52.62ms -step:400/20000 train_loss:2.4119 train_time:25089ms step_avg:62.72ms -step:600/20000 train_loss:2.2870 train_time:39146ms step_avg:65.24ms -step:800/20000 train_loss:2.3682 train_time:53332ms step_avg:66.67ms -step:1000/20000 train_loss:2.3258 train_time:63778ms step_avg:63.78ms -step:1000/20000 val_loss:2.3036 val_bpb:1.3643 train_time:63809ms step_avg:63.81ms -step:1200/20000 train_loss:2.3500 train_time:79075ms step_avg:65.90ms -step:1400/20000 train_loss:2.3032 train_time:93065ms step_avg:66.48ms -step:1600/20000 train_loss:2.2669 train_time:106365ms step_avg:66.48ms -step:1800/20000 train_loss:2.0491 train_time:119794ms step_avg:66.55ms -step:2000/20000 train_loss:2.3109 train_time:130239ms step_avg:65.12ms -step:2000/20000 val_loss:2.2128 val_bpb:1.3106 train_time:130273ms step_avg:65.14ms -step:2200/20000 train_loss:2.0611 train_time:144178ms step_avg:65.54ms -step:2400/20000 train_loss:2.2093 train_time:157566ms step_avg:65.65ms -step:2600/20000 train_loss:2.3236 train_time:170759ms step_avg:65.68ms -step:2800/20000 train_loss:2.3997 train_time:185694ms step_avg:66.32ms -step:3000/20000 train_loss:2.1766 train_time:196104ms step_avg:65.37ms -step:3000/20000 val_loss:2.1566 val_bpb:1.2773 train_time:196141ms step_avg:65.38ms -step:3200/20000 train_loss:2.1982 train_time:209488ms step_avg:65.46ms -step:3400/20000 train_loss:2.2025 train_time:225460ms step_avg:66.31ms -step:3600/20000 train_loss:2.0799 train_time:240943ms step_avg:66.93ms -step:3800/20000 train_loss:2.0977 train_time:251427ms step_avg:66.16ms -step:4000/20000 train_loss:1.9322 train_time:266529ms step_avg:66.63ms -step:4000/20000 val_loss:2.1237 val_bpb:1.2578 train_time:266559ms step_avg:66.64ms -step:4200/20000 train_loss:2.0105 train_time:281249ms step_avg:66.96ms -step:4400/20000 train_loss:2.1162 train_time:297764ms step_avg:67.67ms -step:4600/20000 train_loss:2.0382 train_time:312481ms step_avg:67.93ms -step:4800/20000 train_loss:2.0747 train_time:322932ms step_avg:67.28ms -step:5000/20000 train_loss:2.1439 train_time:338935ms step_avg:67.79ms -step:5000/20000 val_loss:2.1047 val_bpb:1.2465 train_time:338964ms step_avg:67.79ms -step:5200/20000 train_loss:2.2006 train_time:354715ms step_avg:68.21ms -step:5400/20000 train_loss:1.9001 train_time:371359ms step_avg:68.77ms -step:5600/20000 train_loss:2.2526 train_time:386064ms step_avg:68.94ms -step:5800/20000 train_loss:2.1734 train_time:396507ms step_avg:68.36ms -step:6000/20000 train_loss:1.9981 train_time:413445ms step_avg:68.91ms -step:6000/20000 val_loss:2.0865 val_bpb:1.2358 train_time:413479ms step_avg:68.91ms -step:6200/20000 train_loss:2.1487 train_time:429320ms step_avg:69.25ms -step:6400/20000 train_loss:2.0919 train_time:444877ms step_avg:69.51ms -step:6600/20000 train_loss:2.0824 train_time:455307ms step_avg:68.99ms -step:6800/20000 train_loss:2.1572 train_time:472340ms step_avg:69.46ms -step:7000/20000 train_loss:2.1331 train_time:488662ms step_avg:69.81ms -step:7000/20000 val_loss:2.0582 val_bpb:1.2190 train_time:488692ms step_avg:69.81ms -step:7200/20000 train_loss:2.0539 train_time:503944ms step_avg:69.99ms -step:7400/20000 train_loss:2.0565 train_time:520195ms step_avg:70.30ms -step:7600/20000 train_loss:2.0145 train_time:530612ms step_avg:69.82ms -step:7800/20000 train_loss:2.0521 train_time:546496ms step_avg:70.06ms -step:8000/20000 train_loss:2.1062 train_time:562870ms step_avg:70.36ms -step:8000/20000 val_loss:2.0335 val_bpb:1.2043 train_time:562900ms step_avg:70.36ms -step:8200/20000 train_loss:2.1051 train_time:578448ms step_avg:70.54ms -step:8400/20000 train_loss:1.9911 train_time:594270ms step_avg:70.75ms -step:8508/20000 val_loss:2.0255 val_bpb:1.1996 train_time:599920ms step_avg:70.51ms -stopping_early: wallclock_cap train_time:599920ms step:8508/20000 -peak memory allocated: 7748 MiB reserved: 8066 MiB -Serialized model: 67224983 bytes -Code size: 47759 bytes -Total submission size: 67272742 bytes -Serialized model int8+zlib: 15839491 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15887250 bytes -final_int8_zlib_roundtrip val_loss:2.0315 val_bpb:1.2032 eval_time:2098ms -final_int8_zlib_roundtrip_exact val_loss:2.03154449 val_bpb:1.20319508 diff --git a/records/track_10min_16mb/2026-03-19_WarmdownQuantization/README.md b/records/track_10min_16mb/2026-03-19_WarmdownQuantization/README.md deleted file mode 100644 index 86a3e5d610..0000000000 --- a/records/track_10min_16mb/2026-03-19_WarmdownQuantization/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# Warmdown-Quantization: Training for Compression - -## Score -**val_bpb = 1.2154** (baseline: 1.2244, improvement: 0.009 BPB / 0.017 nats) - -## Key Insight - -On 8xH100, the dominant bottleneck isn't model quality — it's quantization quality. The post-training int8 quantization penalty (0.014 BPB with default settings) is larger than most hyperparameter improvements combined. We attack this bottleneck from multiple angles. - -## Novel Contributions - -### 1. Always-Decaying Learning Rate Schedule (WARMDOWN_ITERS=20000) - -Setting WARMDOWN_ITERS far beyond the actual training steps (~12,200) produces dramatically better post-quantization quality. The LR decays linearly from 61% of peak at step 0 to near-zero at the final step. - -Aggressive LR decay produces tighter weight distributions with fewer outliers. Since int8 quantization error is proportional to the weight range per row, smoother weights map to the int8 grid with much less damage. - -Post-quant penalty drops from 0.014 BPB (WD=1200 default) to 0.005 BPB (WD=20000). We mapped the full curve across 10 warmdown values, finding the sweet spot at WD=20000 where the entire training run is in the decay phase. WD=30000 overshoots — too little high-LR learning. - -### 2. FP16 Tied Embeddings - -The tied embedding matrix (tok_emb.weight) serves dual roles as input lookup and output projection. Int8 quantization causes disproportionate damage because small errors affect both input representation quality AND output logit accuracy. Keeping it in fp16 during quantization reduces the remaining post-quant penalty from 0.005 to ~0.001 BPB at a cost of ~500KB (offset by reducing MLP hidden from 1024 to 992). - -### 3. Optimal NTK-RoPE Extrapolation for Well-Trained Models - -The optimal eval sequence length depends on training convergence: -- Undertrained models (1xH100, ~1,600 steps): eval@2048 gives +0.048 BPB -- Well-trained models (8xH100, ~12,200 steps): eval@2048 is neutral-to-negative; eval@1408 (1.375x) is optimal (+0.007 BPB) - -Well-trained models develop precise position-dependent patterns that aggressive NTK extrapolation distorts. Moderate extrapolation provides useful extra context without excessive distortion. - -### 4. Optimizer-Warmdown Interaction - -MUON_BACKEND_STEPS=5 outperforms 7 when combined with aggressive warmdown (WD=20000), despite 7 outperforming 5 at default warmdown (WD=2400). When warmdown already produces smooth weights, more training steps are more valuable than better per-step gradient quality. - -## Configuration - -``` -WARMDOWN_ITERS=20000 MATRIX_LR=0.06 TIED_EMBED_LR=0.07 SCALAR_LR=0.06 -GRAD_CLIP_NORM=1.0 MUON_BACKEND_STEPS=5 EVAL_SEQ_LEN=1408 -``` -- FP16 tied embedding (tok_emb.weight kept in fp16 during int8 export) -- MLP_HIDDEN=992 (offset FP16 embedding overhead) - -## Reproduction - -```bash -WARMDOWN_ITERS=20000 MATRIX_LR=0.06 TIED_EMBED_LR=0.07 SCALAR_LR=0.06 \ -GRAD_CLIP_NORM=1.0 MUON_BACKEND_STEPS=5 EVAL_SEQ_LEN=1408 MLP_HIDDEN=992 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -## Hardware Note - -Results obtained on RunPod 8xH100 SXM (47-48ms/step vs baseline's 43.5ms/step). Scores should improve when re-evaluated on OpenAI's faster hardware due to additional training steps within the 10-minute window. diff --git a/records/track_10min_16mb/2026-03-19_WarmdownQuantization/submission.json b/records/track_10min_16mb/2026-03-19_WarmdownQuantization/submission.json deleted file mode 100644 index 07c58c39c2..0000000000 --- a/records/track_10min_16mb/2026-03-19_WarmdownQuantization/submission.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "author": "samuellarson", - "github_id": "samuellarson", - "name": "Int6 MLP3x Sliding Window", - "blurb": "Int6 post-training quantization enables 3x MLP expansion (21.8M params in 16MB). Combined with train@2048 + sliding window eval + FP16 tied embeddings + Late-K passthrough.", - "date": "2026-03-20", - "val_loss": 1.95428963, - "val_bpb": 1.15744040, - "bytes_total": 15977717, - "bytes_code": 51200 -} diff --git a/records/track_10min_16mb/2026-03-19_WarmdownQuantization/train.log b/records/track_10min_16mb/2026-03-19_WarmdownQuantization/train.log deleted file mode 100644 index c10e30c24b..0000000000 --- a/records/track_10min_16mb/2026-03-19_WarmdownQuantization/train.log +++ /dev/null @@ -1,1403 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: - max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # FP16 passthrough for tied embedding (our trick) - if name == "tok_emb.weight": - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Late-K passthrough: keep last 2 layers' key weights in fp16 (PR #99's trick) - num_layers_total = max((int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), default=0) + 1 - if name.endswith("c_k.weight") and any(f"blocks.{i}." in name for i in range(num_layers_total - 2, num_layers_total)): - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Small float tensors are cheap enough to keep directly. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Everything else: int6 quantization (saves ~25% vs int8) - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t, bits=6) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - # NTK-aware RoPE scaling for sequence length extrapolation at eval time. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): - super().__init__() - hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - mlp_hidden: int = 0, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult, mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - mlp_hidden: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - mlp_hidden=mlp_hidden, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - @torch.no_grad() - def get_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args, base_model: nn.Module, rank: int, world_size: int, device: torch.device, - val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, eval_seq_len: int, eval_stride: int, -) -> tuple[float, float]: - total_tokens = val_tokens.numel() - 1 - all_starts = list(range(0, total_tokens - eval_seq_len + 1, eval_stride)) - my_starts = all_starts[rank::world_size] - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for start in my_starts: - end = start + eval_seq_len - x = val_tokens[start:end].to(device=device, dtype=torch.int64).unsqueeze(0) - y = val_tokens[start + 1:end + 1].to(device=device, dtype=torch.int64).unsqueeze(0) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.get_logits(x) - score_from = eval_seq_len - eval_stride - if start == 0: - score_from = 0 - suffix_logits = logits[0, score_from:].float() - suffix_targets = y[0, score_from:] - per_pos_loss = F.cross_entropy(suffix_logits, suffix_targets, reduction="none") - val_loss_sum += per_pos_loss.to(torch.float64).sum() - val_token_count += per_pos_loss.numel() - prev_ids = x[0, score_from:] - tgt_ids = y[0, score_from:] - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - mlp_hidden=args.mlp_hidden, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms " - f"eval_seq_len:{effective_eval_seq_len}" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if args.eval_stride > 0: - torch.cuda.synchronize() - t_slide = time.perf_counter() - s_val_loss, s_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, eval_stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window val_loss:{s_val_loss:.4f} val_bpb:{s_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms " - f"stride:{args.eval_stride} seq_len:{effective_eval_seq_len}" - ) - log0(f"final_sliding_window_exact val_loss:{s_val_loss:.8f} val_bpb:{s_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0] -Running PyTorch 2.6.0+cu124 -Thu Mar 19 18:07:21 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | -| N/A 37C P0 118W / 700W | 7088MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | -| N/A 32C P0 115W / 700W | 3455MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | -| N/A 31C P0 117W / 700W | 3455MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | -| N/A 37C P0 120W / 700W | 3455MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | -| N/A 36C P0 118W / 700W | 3455MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 32C P0 117W / 700W | 3455MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | -| N/A 31C P0 117W / 700W | 3455MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | -| N/A 37C P0 118W / 700W | 3215MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 330090 C /usr/bin/python 3398MiB | -| 0 N/A N/A 330091 C /usr/bin/python 520MiB | -| 0 N/A N/A 330092 C /usr/bin/python 520MiB | -| 0 N/A N/A 330093 C /usr/bin/python 520MiB | -| 0 N/A N/A 330094 C /usr/bin/python 520MiB | -| 0 N/A N/A 330095 C /usr/bin/python 520MiB | -| 0 N/A N/A 330096 C /usr/bin/python 520MiB | -| 0 N/A N/A 330097 C /usr/bin/python 520MiB | -| 1 N/A N/A 330091 C /usr/bin/python 3446MiB | -| 2 N/A N/A 330092 C /usr/bin/python 3446MiB | -| 3 N/A N/A 330093 C /usr/bin/python 3446MiB | -| 4 N/A N/A 330094 C /usr/bin/python 3446MiB | -| 5 N/A N/A 330095 C /usr/bin/python 3446MiB | -| 6 N/A N/A 330096 C /usr/bin/python 3446MiB | -| 7 N/A N/A 330097 C /usr/bin/python 3206MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:21778504 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9364 val_bpb:4.1081 train_time:0ms step_avg:0.04ms -step:1/20000 train_loss:6.9368 train_time:240ms step_avg:240.35ms -step:2/20000 train_loss:11.8525 train_time:294ms step_avg:147.02ms -step:3/20000 train_loss:11.3195 train_time:386ms step_avg:128.79ms -step:4/20000 train_loss:8.9748 train_time:479ms step_avg:119.85ms -step:5/20000 train_loss:7.1182 train_time:570ms step_avg:114.09ms -step:6/20000 train_loss:6.2765 train_time:661ms step_avg:110.18ms -step:7/20000 train_loss:6.1699 train_time:754ms step_avg:107.78ms -step:8/20000 train_loss:6.1337 train_time:847ms step_avg:105.83ms -step:9/20000 train_loss:5.9695 train_time:939ms step_avg:104.29ms -step:10/20000 train_loss:5.7977 train_time:1032ms step_avg:103.21ms -step:200/20000 train_loss:2.3819 train_time:16744ms step_avg:83.72ms -step:400/20000 train_loss:2.4145 train_time:33413ms step_avg:83.53ms -step:600/20000 train_loss:2.3334 train_time:50051ms step_avg:83.42ms -step:800/20000 train_loss:2.2392 train_time:66794ms step_avg:83.49ms -step:1000/20000 train_loss:2.2752 train_time:83413ms step_avg:83.41ms -step:1200/20000 train_loss:2.3535 train_time:100153ms step_avg:83.46ms -step:1400/20000 train_loss:2.1814 train_time:116858ms step_avg:83.47ms -step:1600/20000 train_loss:2.0815 train_time:133428ms step_avg:83.39ms -step:1800/20000 train_loss:2.1712 train_time:150121ms step_avg:83.40ms -step:2000/20000 train_loss:2.0736 train_time:166715ms step_avg:83.36ms -step:2200/20000 train_loss:2.1553 train_time:183457ms step_avg:83.39ms -step:2400/20000 train_loss:2.0710 train_time:200084ms step_avg:83.37ms -step:2600/20000 train_loss:2.1033 train_time:216842ms step_avg:83.40ms -step:2800/20000 train_loss:2.1519 train_time:233604ms step_avg:83.43ms -step:3000/20000 train_loss:2.1580 train_time:250291ms step_avg:83.43ms -step:3200/20000 train_loss:2.1602 train_time:267045ms step_avg:83.45ms -step:3400/20000 train_loss:2.0096 train_time:283643ms step_avg:83.42ms -step:3600/20000 train_loss:2.0850 train_time:300363ms step_avg:83.43ms -step:3800/20000 train_loss:2.0634 train_time:316945ms step_avg:83.41ms -step:4000/20000 train_loss:1.9683 train_time:333637ms step_avg:83.41ms -step:4200/20000 train_loss:2.1467 train_time:350341ms step_avg:83.41ms -step:4400/20000 train_loss:2.0295 train_time:366916ms step_avg:83.39ms -step:4600/20000 train_loss:1.8411 train_time:383617ms step_avg:83.40ms -step:4800/20000 train_loss:2.4330 train_time:400226ms step_avg:83.38ms -step:5000/20000 train_loss:2.1098 train_time:416933ms step_avg:83.39ms -step:5000/20000 val_loss:2.0282 val_bpb:1.2012 train_time:416986ms step_avg:83.40ms -step:5200/20000 train_loss:2.0506 train_time:433472ms step_avg:83.36ms -step:5400/20000 train_loss:2.0596 train_time:450114ms step_avg:83.35ms -step:5600/20000 train_loss:1.9724 train_time:466779ms step_avg:83.35ms -step:5800/20000 train_loss:2.0161 train_time:483379ms step_avg:83.34ms -step:6000/20000 train_loss:1.9675 train_time:500077ms step_avg:83.35ms -step:6200/20000 train_loss:1.9746 train_time:516678ms step_avg:83.34ms -step:6400/20000 train_loss:2.0335 train_time:533433ms step_avg:83.35ms -step:6600/20000 train_loss:1.8843 train_time:550074ms step_avg:83.34ms -step:6800/20000 train_loss:2.0721 train_time:566784ms step_avg:83.35ms -step:7000/20000 train_loss:1.8462 train_time:583490ms step_avg:83.36ms -step:7199/20000 val_loss:1.9801 val_bpb:1.1727 train_time:600031ms step_avg:83.35ms -stopping_early: wallclock_cap train_time:600031ms step:7199/20000 -peak memory allocated: 16840 MiB reserved: 17936 MiB -Serialized model: 86098946 bytes -Code size: 53569 bytes -Total submission size: 86152515 bytes -Serialized model int8+zlib: 15924148 bytes (payload:22690080 raw_torch:22733906 payload_ratio:3.79x) -Total submission size int8+zlib: 15977717 bytes -final_int8_zlib_roundtrip val_loss:1.9905 val_bpb:1.1789 eval_time:2067ms eval_seq_len:2048 -final_int8_zlib_roundtrip_exact val_loss:1.99052666 val_bpb:1.17890201 -final_sliding_window val_loss:1.9543 val_bpb:1.1574 eval_time:240533ms stride:256 seq_len:2048 -final_sliding_window_exact val_loss:1.95428963 val_bpb:1.15744040 diff --git a/records/track_10min_16mb/2026-03-19_WarmdownQuantization/train_gpt.py b/records/track_10min_16mb/2026-03-19_WarmdownQuantization/train_gpt.py deleted file mode 100644 index 53ae1843ee..0000000000 --- a/records/track_10min_16mb/2026-03-19_WarmdownQuantization/train_gpt.py +++ /dev/null @@ -1,1246 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: - max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # FP16 passthrough for tied embedding (our trick) - if name == "tok_emb.weight": - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Late-K passthrough: keep last 2 layers' key weights in fp16 (PR #99's trick) - num_layers_total = max((int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), default=0) + 1 - if name.endswith("c_k.weight") and any(f"blocks.{i}." in name for i in range(num_layers_total - 2, num_layers_total)): - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Small float tensors are cheap enough to keep directly. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - # Everything else: int6 quantization (saves ~25% vs int8) - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t, bits=6) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - # NTK-aware RoPE scaling for sequence length extrapolation at eval time. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): - super().__init__() - hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - mlp_hidden: int = 0, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult, mlp_hidden) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - mlp_hidden: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - mlp_hidden=mlp_hidden, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - @torch.no_grad() - def get_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args, base_model: nn.Module, rank: int, world_size: int, device: torch.device, - val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, eval_seq_len: int, eval_stride: int, -) -> tuple[float, float]: - total_tokens = val_tokens.numel() - 1 - all_starts = list(range(0, total_tokens - eval_seq_len + 1, eval_stride)) - my_starts = all_starts[rank::world_size] - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for start in my_starts: - end = start + eval_seq_len - x = val_tokens[start:end].to(device=device, dtype=torch.int64).unsqueeze(0) - y = val_tokens[start + 1:end + 1].to(device=device, dtype=torch.int64).unsqueeze(0) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.get_logits(x) - score_from = eval_seq_len - eval_stride - if start == 0: - score_from = 0 - suffix_logits = logits[0, score_from:].float() - suffix_targets = y[0, score_from:] - per_pos_loss = F.cross_entropy(suffix_logits, suffix_targets, reduction="none") - val_loss_sum += per_pos_loss.to(torch.float64).sum() - val_token_count += per_pos_loss.numel() - prev_ids = x[0, score_from:] - tgt_ids = y[0, score_from:] - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - mlp_hidden=args.mlp_hidden, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms " - f"eval_seq_len:{effective_eval_seq_len}" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if args.eval_stride > 0: - torch.cuda.synchronize() - t_slide = time.perf_counter() - s_val_loss, s_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, eval_stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window val_loss:{s_val_loss:.4f} val_bpb:{s_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms " - f"stride:{args.eval_stride} seq_len:{effective_eval_seq_len}" - ) - log0(f"final_sliding_window_exact val_loss:{s_val_loss:.8f} val_bpb:{s_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-19_int6_STE QAT_ MLP_bigram _U_Net/train.log b/records/track_10min_16mb/2026-03-19_int6_STE QAT_ MLP_bigram _U_Net/train.log deleted file mode 100644 index 505225703f..0000000000 --- a/records/track_10min_16mb/2026-03-19_int6_STE QAT_ MLP_bigram _U_Net/train.log +++ /dev/null @@ -1,1540 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard as zstd - HAS_ZSTD = True -except ImportError: - HAS_ZSTD = False - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride controls how many tokens to slide between windows. - # Only the last `eval_stride` tokens per window are scored, giving each scored token - # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # BigramHash: inject token-pair context via a hash-table embedding. - bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 4096)) - bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding_window( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int = 64, -) -> tuple[float, float]: - """Sliding window evaluation: each scored token gets (seq_len - stride) context. - - Instead of chopping validation into non-overlapping 1024-token blocks (where - the first token in each block gets zero context), we slide a 1024-token window - by `stride` tokens at a time and only score the last `stride` tokens per window. - Every scored token sees 960+ tokens of context, dramatically improving BPB. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 # need 1 extra for final target - - # Number of windows: first window at offset 0, then slide by stride - num_windows = max((total_tokens - seq_len) // stride + 1, 1) - - # Distribute windows across ranks - win_start = (num_windows * rank) // world_size - win_end = (num_windows * (rank + 1)) // world_size - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Eval batch: how many windows per forward pass. Tune for memory. - eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) - - base_model.eval() - with torch.inference_mode(): - window_list = list(range(win_start, win_end)) - num_batches = (len(window_list) + eval_batch - 1) // eval_batch - for batch_idx in range(num_batches): - batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] - bsz = len(batch_wins) - - # Build input: each window is val_tokens[w*stride : w*stride + seq_len] - inputs = torch.stack([ - val_tokens[w * stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64) # [B, seq_len] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] - - # Score only the last `stride` positions per window. - # logits[:, j, :] predicts the token AFTER position j in the input. - # So logits[:, -stride:, :] predicts tokens at input positions - # [seq_len - stride + 1, seq_len + 1) — which are the targets. - scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] - - # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] - targets = torch.stack([ - val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") - val_loss_sum += loss.to(torch.float64) - val_token_count += float(targets.numel()) - - # BPB: prev_ids are the input tokens at each scored position - prev_ids = torch.stack([ - val_tokens[w * stride + seq_len - stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - tgt_ids = targets - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if rank == 0 and (batch_idx + 1) % 100 == 0: - print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 - -INT6_QUANT_RANGE = 31 # int6: [-31, 31] -INT6_CLIP_Q = 0.9999984 - -# Tensors matching these patterns are stored as fp16 passthrough (not quantized). -# For weights without STE fake-quant protection (e.g. nn.Embedding, bigram hash table). -FP16_PASSTHROUGH_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "FP16_PASSTHROUGH_PATTERNS", - "tok_emb,bigram_hash", - ).split(",") - if pattern -) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Int6 per-row quantization: [-31, 31] range stored in int8 containers. - # The unused high bits compress extremely well with zstd. - clip_abs = ( - torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars still use int8 per-tensor scale (they're tiny). - clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - # fp16 passthrough for tensors without STE protection (tok_emb, bigram_hash). - if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - # During training, apply fake int6 quantization (STE) so the model learns to be robust - # to the post-training quantization that will be applied for the 16MB artifact. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if self.training and w.ndim == 2: - # Fake int6 per-row quantization with Straight-Through Estimator: - # Forward uses quantized weights, backward passes gradients through as-is. - with torch.no_grad(): - w32 = w.float() - clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) - scale = clip_abs / INT6_QUANT_RANGE - w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) - w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() # STE: value of w_q, gradient of w - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class BigramHash(nn.Module): - """Hash-table embedding for token bigrams, projected to model dim. - Maps (prev_token, cur_token) pairs via a simple hash to a learned embedding. - Gives the model cheap character-pair / bigram info before attention. - """ - def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): - super().__init__() - self.num_buckets = num_buckets - self.table = nn.Embedding(num_buckets, hash_dim) - self.proj = CastedLinear(hash_dim, model_dim, bias=False) - self.proj._zero_init = True - nn.init.normal_(self.table.weight, std=0.01) - - def forward(self, input_ids: Tensor) -> Tensor: - bsz, seqlen = input_ids.shape - prev_ids = torch.cat([torch.zeros(bsz, 1, dtype=input_ids.dtype, device=input_ids.device), input_ids[:, :-1]], dim=1) - h = ((prev_ids.long() * 92821 + input_ids.long()) % self.num_buckets).long() - return self.proj(self.table(h)) - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_hash_buckets: int = 0, - bigram_hash_dim: int = 128, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram_hash = BigramHash(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash_buckets > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram_hash is not None: - x = x + self.bigram_hash(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits without computing loss. Used for sliding window eval.""" - x = self.tok_emb(input_ids) - if self.bigram_hash is not None: - x = x + self.bigram_hash(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_hash_buckets=args.bigram_hash_buckets, - bigram_hash_dim=args.bigram_hash_dim, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - # Collect embedding-like params - embed_params = [base_model.tok_emb.weight] - if base_model.bigram_hash is not None: - embed_params.append(base_model.bigram_hash.table.weight) - # bigram_hash.proj is a CastedLinear — its weight goes to Muon - matrix_params.append(base_model.bigram_hash.proj.weight) - optimizer_tok = torch.optim.Adam( - [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if HAS_ZSTD: - cctx = zstd.ZstdCompressor(level=22) - quant_blob = cctx.compress(quant_raw) - compress_name = "zstd-22" - else: - quant_blob = zlib.compress(quant_raw, level=9) - compress_name = "zlib-9" - quant_raw_bytes = len(quant_raw) - if master_process: - quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" - with open(f"final_model.{quant_ext}", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - quant_ext = "int6.ptz" - with open(f"final_model.{quant_ext}", "rb") as f: - quant_blob_disk = f.read() - if HAS_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_decompressed = dctx.decompress(quant_blob_disk) - else: - quant_decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval: gives each scored token (seq_len - stride) context. - if args.eval_stride > 0: - torch.cuda.synchronize() - t_sw = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding_window( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window_eval stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" - ) - log0( - f"final_sliding_window_eval_exact stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Fri Mar 20 02:18:24 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | -| N/A 34C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | -| N/A 31C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | -| N/A 30C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | -| N/A 34C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 31C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | -| N/A 31C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | -| N/A 35C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 49921 C /usr/local/bin/python 1510MiB | -| 1 N/A N/A 49922 C /usr/local/bin/python 1510MiB | -| 2 N/A N/A 49923 C /usr/local/bin/python 1510MiB | -| 3 N/A N/A 49924 C /usr/local/bin/python 1510MiB | -| 4 N/A N/A 49925 C /usr/local/bin/python 1510MiB | -| 5 N/A N/A 49926 C /usr/local/bin/python 1510MiB | -| 6 N/A N/A 49927 C /usr/local/bin/python 1510MiB | -| 7 N/A N/A 49928 C /usr/local/bin/python 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:22368328 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9364 train_time:31ms step_avg:30.55ms -step:2/20000 train_loss:12.1409 train_time:82ms step_avg:40.86ms -step:3/20000 train_loss:7.1968 train_time:135ms step_avg:44.91ms -step:4/20000 train_loss:6.4291 train_time:190ms step_avg:47.51ms -step:5/20000 train_loss:6.9210 train_time:245ms step_avg:48.95ms -step:6/20000 train_loss:7.5366 train_time:301ms step_avg:50.13ms -step:7/20000 train_loss:6.7198 train_time:355ms step_avg:50.65ms -step:8/20000 train_loss:6.3634 train_time:409ms step_avg:51.09ms -step:9/20000 train_loss:6.2161 train_time:464ms step_avg:51.52ms -step:10/20000 train_loss:6.0286 train_time:517ms step_avg:51.68ms -step:200/20000 train_loss:2.7939 train_time:9860ms step_avg:49.30ms -step:400/20000 train_loss:2.2882 train_time:19689ms step_avg:49.22ms -step:600/20000 train_loss:2.4945 train_time:29565ms step_avg:49.28ms -step:800/20000 train_loss:2.2566 train_time:39444ms step_avg:49.30ms -step:1000/20000 train_loss:2.3514 train_time:49340ms step_avg:49.34ms -step:1200/20000 train_loss:2.3707 train_time:59244ms step_avg:49.37ms -step:1400/20000 train_loss:2.4180 train_time:69137ms step_avg:49.38ms -step:1600/20000 train_loss:2.0918 train_time:79026ms step_avg:49.39ms -step:1800/20000 train_loss:2.1879 train_time:88911ms step_avg:49.40ms -step:2000/20000 train_loss:2.2381 train_time:98804ms step_avg:49.40ms -step:2200/20000 train_loss:2.0439 train_time:108685ms step_avg:49.40ms -step:2400/20000 train_loss:2.1727 train_time:118547ms step_avg:49.39ms -step:2600/20000 train_loss:2.3791 train_time:128432ms step_avg:49.40ms -step:2800/20000 train_loss:2.1977 train_time:138299ms step_avg:49.39ms -step:3000/20000 train_loss:2.1923 train_time:148160ms step_avg:49.39ms -step:3200/20000 train_loss:2.1539 train_time:158049ms step_avg:49.39ms -step:3400/20000 train_loss:2.1235 train_time:167930ms step_avg:49.39ms -step:3600/20000 train_loss:2.0735 train_time:177813ms step_avg:49.39ms -step:3800/20000 train_loss:2.1821 train_time:187714ms step_avg:49.40ms -step:4000/20000 train_loss:2.1263 train_time:197620ms step_avg:49.41ms -step:4000/20000 val_loss:2.1281 val_bpb:1.2604 train_time:197653ms step_avg:49.41ms -step:4200/20000 train_loss:2.1315 train_time:207621ms step_avg:49.43ms -step:4400/20000 train_loss:2.0687 train_time:217521ms step_avg:49.44ms -step:4600/20000 train_loss:1.9267 train_time:227415ms step_avg:49.44ms -step:4800/20000 train_loss:2.2176 train_time:237314ms step_avg:49.44ms -step:5000/20000 train_loss:1.9865 train_time:247207ms step_avg:49.44ms -step:5200/20000 train_loss:2.1270 train_time:257113ms step_avg:49.44ms -step:5400/20000 train_loss:2.1403 train_time:267008ms step_avg:49.45ms -step:5600/20000 train_loss:2.1378 train_time:276906ms step_avg:49.45ms -step:5800/20000 train_loss:2.0980 train_time:286918ms step_avg:49.47ms -step:6000/20000 train_loss:2.1798 train_time:297160ms step_avg:49.53ms -step:6200/20000 train_loss:2.0437 train_time:307259ms step_avg:49.56ms -step:6400/20000 train_loss:2.1160 train_time:317124ms step_avg:49.55ms -step:6600/20000 train_loss:2.0748 train_time:326994ms step_avg:49.54ms -step:6800/20000 train_loss:2.1477 train_time:336875ms step_avg:49.54ms -step:7000/20000 train_loss:2.1806 train_time:346747ms step_avg:49.54ms -step:7200/20000 train_loss:2.1479 train_time:356620ms step_avg:49.53ms -step:7400/20000 train_loss:2.0732 train_time:366491ms step_avg:49.53ms -step:7600/20000 train_loss:1.9514 train_time:376370ms step_avg:49.52ms -step:7800/20000 train_loss:2.1021 train_time:386252ms step_avg:49.52ms -step:8000/20000 train_loss:2.0685 train_time:396124ms step_avg:49.52ms -step:8000/20000 val_loss:2.0762 val_bpb:1.2297 train_time:396152ms step_avg:49.52ms -step:8200/20000 train_loss:2.1457 train_time:405994ms step_avg:49.51ms -step:8400/20000 train_loss:2.0894 train_time:415948ms step_avg:49.52ms -step:8600/20000 train_loss:2.0913 train_time:425810ms step_avg:49.51ms -step:8800/20000 train_loss:2.0542 train_time:435719ms step_avg:49.51ms -step:9000/20000 train_loss:1.9794 train_time:445595ms step_avg:49.51ms -step:9200/20000 train_loss:2.0395 train_time:455473ms step_avg:49.51ms -step:9400/20000 train_loss:2.0786 train_time:465361ms step_avg:49.51ms -step:9600/20000 train_loss:2.0926 train_time:475238ms step_avg:49.50ms -step:9800/20000 train_loss:2.0138 train_time:485125ms step_avg:49.50ms -step:10000/20000 train_loss:2.0578 train_time:495051ms step_avg:49.51ms -step:10200/20000 train_loss:2.0050 train_time:504951ms step_avg:49.50ms -step:10400/20000 train_loss:2.0338 train_time:514831ms step_avg:49.50ms -step:10600/20000 train_loss:1.9155 train_time:524714ms step_avg:49.50ms -step:10800/20000 train_loss:2.1173 train_time:534599ms step_avg:49.50ms -step:11000/20000 train_loss:2.0472 train_time:544503ms step_avg:49.50ms -step:11200/20000 train_loss:1.9902 train_time:554367ms step_avg:49.50ms -step:11400/20000 train_loss:1.9795 train_time:564228ms step_avg:49.49ms -step:11600/20000 train_loss:1.9797 train_time:574096ms step_avg:49.49ms -step:11800/20000 train_loss:2.0115 train_time:583977ms step_avg:49.49ms -step:12000/20000 train_loss:1.9875 train_time:593849ms step_avg:49.49ms -step:12000/20000 val_loss:2.0165 val_bpb:1.1943 train_time:593877ms step_avg:49.49ms -step:12123/20000 val_loss:2.0158 val_bpb:1.1939 train_time:599995ms step_avg:49.49ms -stopping_early: wallclock_cap train_time:599995ms step:12123/20000 -peak memory allocated: 11273 MiB reserved: 11438 MiB -Serialized model: 87410841 bytes -Code size: 58666 bytes -Total submission size: 87469507 bytes -Serialized model int6+zstd-22: 16127834 bytes (payload:23608608 raw_torch:23654103 payload_ratio:3.70x) -Total submission size int6+zstd-22: 16186500 bytes -final_int6_roundtrip val_loss:2.0145 val_bpb:1.1931 eval_time:1530ms -final_int6_roundtrip_exact val_loss:2.01453600 val_bpb:1.19312169 -final_sliding_window_eval stride:64 val_loss:1.9582 val_bpb:1.1598 eval_time:74131ms -final_sliding_window_eval_exact stride:64 val_loss:1.95823653 val_bpb:1.15977898 diff --git a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md deleted file mode 100644 index 843604db97..0000000000 --- a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/README.md +++ /dev/null @@ -1,137 +0,0 @@ -# SmearGate + OrthoInit + Muon WD + Int6 STE QAT + MLP 3x + Sliding Window - -**val\_bpb: 1.1556** (post-quant int6+zstd-22, sliding window eval stride=64) - -## Summary - -A 22.4M parameter transformer language model trained in under 10 minutes on 8×H100 GPUs, compressed to a 15.1MB artifact via int6 quantization-aware training and zstd-22. The architecture combines a SmearGate bigram embedding layer, orthogonal weight initialization, 3× MLP expansion, U-Net skip connections, and decoupled Muon weight decay, evaluated with sliding window context at stride 64. - -## Architecture - -### Transformer Core - -A 9-layer, 512-dim transformer with 8 attention heads (4 KV heads via grouped-query attention) and tied input/output embeddings over a 1024-token BPE vocabulary. Sequence length during training is 1024 tokens. - -### SmearGate - -A learned per-dimension gate (\~512 params) that blends each token's embedding with the previous token's embedding before the transformer processes anything: - -```python -gate = sigmoid(self.gate) # shape \[dim], init ≈ 0.95 -output = gate \* current\_emb + (1 - gate) \* prev\_token\_emb -``` - -This injects bigram (two-token) context directly into the embedding layer. Normally a transformer must discover token-pair relationships through self-attention; SmearGate provides this signal for free. The gate is initialized via `sigmoid(3.0) ≈ 0.95` so it starts near-identity (mostly current token), and the model learns per-dimension how much previous-token blending is useful. - -Applied after embedding lookup and bigram hash addition, before RMS normalization. - -### Bigram Hash Embedding - -A 4096-bucket hash table (dim=128, projected to 512) maps consecutive token pairs to learned embeddings via `(prev \* 92821 + cur) % 4096`. This gives the model direct access to token-pair features at minimal parameter cost. - -### MLP 3× Expansion - -MLP hidden dimension is 3× the model dimension (1536 for a 512-dim model). The space savings from int6 quantization fund this extra capacity — wider MLPs allow more expressive nonlinear feature transformation between attention operations. - -### U-Net Skip Connections - -The 9-layer transformer is split into an encoder half (4 layers) and a decoder half (5 layers) with learned skip weights connecting corresponding encoder/decoder layers. This gives the decoder direct access to earlier representations without relying solely on the residual stream. - -## Training - -### Muon Optimizer with Weight Decay - -The Muon optimizer (MomentUm Orthogonalized by Newton-Schulz) runs SGD with Nesterov momentum, then post-processes each 2D parameter's gradient update by replacing it with the nearest orthogonal matrix via 5-step Newton-Schulz iteration. This is equivalent to steepest descent under the spectral norm, improving the conditioning of the optimization landscape. - -Decoupled weight decay (`p.mul\_(1 - wd \* lr)`, wd=0.01) is applied before each gradient update. This keeps weights smaller and better-distributed, which directly benefits both generalization and downstream quantization — tighter weight distributions quantize into fewer int6 buckets with less error and compress better with zstd. - -Momentum is warmed from 0.92 → 0.99 over the first 1500 steps. - -### Orthogonal Weight Initialization - -All non-zero-init CastedLinear weight matrices are initialized with `nn.init.orthogonal\_()`. Orthogonal matrices have all singular values equal to 1, meaning gradients flow uniformly through the network at initialization with no vanishing or exploding signals. Additionally, since Muon's Newton-Schulz step orthogonalizes updates, starting from an already-orthogonal matrix means early updates are immediately useful rather than spent correcting a random initialization. With only \~12k steps in the 10-minute budget, faster convergence matters. - -### Int6 Quantization-Aware Training (STE) - -All 2D weight matrices are fake-quantized to int6 (\[-31, 31]) during every forward pass via Straight-Through Estimator — the forward pass sees quantized weights while gradients flow through the rounding operation as if it were identity. The model learns weight configurations that are inherently robust to post-training quantization. The tied embedding matrix is stored as fp16 passthrough (not quantized), since it serves double duty for both input embeddings and output predictions where errors compound in both directions. - -### Learning Rate Schedule - -Warmup over 20 steps, followed by linear warmdown over the final 3000 steps. Separate learning rates for tied embeddings (0.030), matrix parameters (0.020), and scalar parameters (0.020). - -## Evaluation - -### Sliding Window (stride=64) - -Instead of chopping validation text into non-overlapping chunks (where tokens near the start of each chunk lack context), sliding window uses overlapping windows with stride 64 and the full 1024-token context window. Each scored token gets 960+ tokens of prior context. This is purely an evaluation-time technique — it does not change the model. - -## Export - -### Int6 + zstd-22 Compression - -All quantized weights are packed into int8 containers and compressed with zstandard at level 22. The int6 representation plus aggressive compression brings the full submission (model + code) to 15.1MB, under the 16MB cap. - -## Metrics - -|Metric|Value| -|-|-| -|**Post-quant sliding window val\_bpb**|**1.1556**| -|Post-quant sliding window val\_loss|1.9511| -|Post-quant standard val\_bpb|1.1891| -|Post-quant standard val\_loss|2.0077| -|Quantization gap (standard eval)|\~0.0001 BPB| -|Model parameters|22,368,840| -|Artifact size (int6+zstd-22)|15,878,809 bytes (15.1 MB)| -|Train steps completed|12,047| -|Train time|600s (10.0 min)| -|Sliding window eval time|75s| -|Peak GPU memory|11,340 MiB| - -## Configuration - -``` -VOCAB\_SIZE=1024 -NUM\_LAYERS=9 -MODEL\_DIM=512 -NUM\_HEADS=8 -NUM\_KV\_HEADS=4 -MLP\_MULT=3 -TIE\_EMBEDDINGS=1 -USE\_SMEARGATE=1 -TRAIN\_SEQ\_LEN=1024 -TRAIN\_BATCH\_TOKENS=524288 -LOGIT\_SOFTCAP=30.0 -ROPE\_BASE=10000.0 -QK\_GAIN\_INIT=1.5 -BIGRAM\_HASH\_BUCKETS=4096 -BIGRAM\_HASH\_DIM=128 -TIED\_EMBED\_LR=0.030 -MATRIX\_LR=0.020 -SCALAR\_LR=0.020 -MUON\_MOMENTUM=0.99 -MUON\_MOMENTUM\_WARMUP\_START=0.92 -MUON\_MOMENTUM\_WARMUP\_STEPS=1500 -MUON\_WEIGHT\_DECAY=0.01 -MUON\_BACKEND\_STEPS=5 -WARMDOWN\_ITERS=3000 -WARMUP\_STEPS=20 -EVAL\_STRIDE=64 -MAX\_WALLCLOCK\_SECONDS=600 -SEED=1337 -``` - -## Command - -```bash -RUN\_ID=smeargate\_orthoinit\_muonwd \\ -DATA\_PATH=./data/datasets/fineweb10B\_sp1024 \\ -TOKENIZER\_PATH=./data/tokenizers/fineweb\_1024\_bpe.model \\ -torchrun --standalone --nproc\_per\_node=8 train\_gpt.py -``` - -## Hardware - -8× NVIDIA H100 80GB HBM3 SXM (RunPod). - -## - diff --git a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/submission.json b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/submission.json deleted file mode 100644 index 99426f02a1..0000000000 --- a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/submission.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "val_bpb": 1.1556, - "val_loss": 1.9511, - "eval_method": "sliding_window", - "eval_stride": 64, - "quantization": "int6", - "compression": "zstd-22", - "artifact_size_bytes": 15878809, - "model_params": 22368840, - "train_steps": 12047, - "train_time_seconds": 600, - "hardware": "8xH100-80GB-SXM", - "seed": 1337, - "techniques": [ - "SmearGate", - "Orthogonal Init", - "Muon Weight Decay", - "Int6 STE QAT", - "MLP 3x", - "Sliding Window Eval", - "Bigram Hash Embedding", - "U-Net Skip Connections", - "Muon Momentum Warmup", - "zstd-22 Compression" - ], - "post_quant_standard_val_bpb": 1.1891, - "post_quant_standard_val_loss": 2.0077, - "quantization_gap_bpb": 0.0001 -} diff --git a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train.log b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train.log deleted file mode 100644 index 71fd53fb13..0000000000 --- a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train.log +++ /dev/null @@ -1,1594 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard as zstd - HAS_ZSTD = True -except ImportError: - HAS_ZSTD = False - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride controls how many tokens to slide between windows. - # Only the last `eval_stride` tokens per window are scored, giving each scored token - # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.01)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # BigramHash: inject token-pair context via a hash-table embedding. - bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 4096)) - bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) - - # SmearGate: blend each token's embedding with the previous token's - # via a tiny learned gate. Injects bigram context before the transformer. - use_smeargate = bool(int(os.environ.get("USE_SMEARGATE", "1"))) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group["weight_decay"] - - # Decoupled weight decay: shrink weights toward zero before the - # gradient update. This keeps the weight distribution tight and - # well-centred, which (a) improves generalisation and (b) makes - # post-training int6 quantisation compress better — the tighter - # distribution fits into fewer quantisation buckets with less error. - if wd > 0: - for p in params: - p.mul_(1.0 - wd * lr) - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding_window( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int = 64, -) -> tuple[float, float]: - """Sliding window evaluation: each scored token gets (seq_len - stride) context. - - Instead of chopping validation into non-overlapping 1024-token blocks (where - the first token in each block gets zero context), we slide a 1024-token window - by `stride` tokens at a time and only score the last `stride` tokens per window. - Every scored token sees 960+ tokens of context, dramatically improving BPB. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 # need 1 extra for final target - - # Number of windows: first window at offset 0, then slide by stride - num_windows = max((total_tokens - seq_len) // stride + 1, 1) - - # Distribute windows across ranks - win_start = (num_windows * rank) // world_size - win_end = (num_windows * (rank + 1)) // world_size - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Eval batch: how many windows per forward pass. Tune for memory. - eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) - - base_model.eval() - with torch.inference_mode(): - window_list = list(range(win_start, win_end)) - num_batches = (len(window_list) + eval_batch - 1) // eval_batch - for batch_idx in range(num_batches): - batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] - bsz = len(batch_wins) - - # Build input: each window is val_tokens[w*stride : w*stride + seq_len] - inputs = torch.stack([ - val_tokens[w * stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64) # [B, seq_len] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] - - # Score only the last `stride` positions per window. - # logits[:, j, :] predicts the token AFTER position j in the input. - # So logits[:, -stride:, :] predicts tokens at input positions - # [seq_len - stride + 1, seq_len + 1) — which are the targets. - scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] - - # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] - targets = torch.stack([ - val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") - val_loss_sum += loss.to(torch.float64) - val_token_count += float(targets.numel()) - - # BPB: prev_ids are the input tokens at each scored position - prev_ids = torch.stack([ - val_tokens[w * stride + seq_len - stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - tgt_ids = targets - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if rank == 0 and (batch_idx + 1) % 100 == 0: - print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 - -INT6_QUANT_RANGE = 31 # int6: [-31, 31] -INT6_CLIP_Q = 0.9999984 - -# Tensors matching these patterns are stored as fp16 passthrough (not quantized). -# For weights without STE fake-quant protection (e.g. nn.Embedding, bigram hash table). -FP16_PASSTHROUGH_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "FP16_PASSTHROUGH_PATTERNS", - "tok_emb,bigram_hash", - ).split(",") - if pattern -) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Int6 per-row quantization: [-31, 31] range stored in int8 containers. - # The unused high bits compress extremely well with zstd. - clip_abs = ( - torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars still use int8 per-tensor scale (they're tiny). - clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - # fp16 passthrough for tensors without STE protection (tok_emb, bigram_hash). - if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - # During training, apply fake int6 quantization (STE) so the model learns to be robust - # to the post-training quantization that will be applied for the 16MB artifact. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if self.training and w.ndim == 2: - # Fake int6 per-row quantization with Straight-Through Estimator: - # Forward uses quantized weights, backward passes gradients through as-is. - with torch.no_grad(): - w32 = w.float() - clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) - scale = clip_abs / INT6_QUANT_RANGE - w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) - w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() # STE: value of w_q, gradient of w - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class BigramHash(nn.Module): - """Hash-table embedding for token bigrams, projected to model dim. - Maps (prev_token, cur_token) pairs via a simple hash to a learned embedding. - Gives the model cheap character-pair / bigram info before attention. - """ - def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): - super().__init__() - self.num_buckets = num_buckets - self.table = nn.Embedding(num_buckets, hash_dim) - self.proj = CastedLinear(hash_dim, model_dim, bias=False) - self.proj._zero_init = True - nn.init.normal_(self.table.weight, std=0.01) - - def forward(self, input_ids: Tensor) -> Tensor: - bsz, seqlen = input_ids.shape - prev_ids = torch.cat([torch.zeros(bsz, 1, dtype=input_ids.dtype, device=input_ids.device), input_ids[:, :-1]], dim=1) - h = ((prev_ids.long() * 92821 + input_ids.long()) % self.num_buckets).long() - return self.proj(self.table(h)) - - -class SmearGate(nn.Module): - """Learned per-dimension gate blending each token's embedding with the - previous token's. Injects bigram (two-token) context directly into the - embedding layer *before* the transformer starts processing. - - Normally a transformer must discover token-pair relationships through - self-attention; SmearGate provides this signal for free at ~dim params. - - Technique originated by @unnir in parameter-golf PR #102/#135. - """ - def __init__(self, dim: int): - super().__init__() - # Initialise so sigmoid(gate) ≈ 0.95 → mostly pass-through at init, - # with a small amount of previous-token blending that the model can - # learn to increase or decrease per dimension. - self.gate = nn.Parameter(torch.full((dim,), 3.0, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate).to(dtype=x.dtype) - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return g * x + (1.0 - g) * x_prev - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_hash_buckets: int = 0, - bigram_hash_dim: int = 128, - use_smeargate: bool = True, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram_hash = BigramHash(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash_buckets > 0 else None - self.smeargate = SmearGate(model_dim) if use_smeargate else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - else: - # Orthogonal init: produces well-conditioned weight matrices - # whose singular values are all 1. This gives the model a - # better starting point — gradients flow more uniformly - # through orthogonal matrices, so early training steps are - # more informative, which matters when you only get ~12k - # steps in the 10-minute budget. - nn.init.orthogonal_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram_hash is not None: - x = x + self.bigram_hash(input_ids) - if self.smeargate is not None: - x = self.smeargate(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits without computing loss. Used for sliding window eval.""" - x = self.tok_emb(input_ids) - if self.bigram_hash is not None: - x = x + self.bigram_hash(input_ids) - if self.smeargate is not None: - x = self.smeargate(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_hash_buckets=args.bigram_hash_buckets, - bigram_hash_dim=args.bigram_hash_dim, - use_smeargate=args.use_smeargate, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - # Collect embedding-like params - embed_params = [base_model.tok_emb.weight] - if base_model.bigram_hash is not None: - embed_params.append(base_model.bigram_hash.table.weight) - # bigram_hash.proj is a CastedLinear — its weight goes to Muon - matrix_params.append(base_model.bigram_hash.proj.weight) - optimizer_tok = torch.optim.Adam( - [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_weight_decay, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - # SmearGate gate is a 1D parameter → goes to scalar optimizer - if base_model.smeargate is not None: - scalar_params.append(base_model.smeargate.gate) - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if HAS_ZSTD: - cctx = zstd.ZstdCompressor(level=22) - quant_blob = cctx.compress(quant_raw) - compress_name = "zstd-22" - else: - quant_blob = zlib.compress(quant_raw, level=9) - compress_name = "zlib-9" - quant_raw_bytes = len(quant_raw) - if master_process: - quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" - with open(f"final_model.{quant_ext}", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - quant_ext = "int6.ptz" - with open(f"final_model.{quant_ext}", "rb") as f: - quant_blob_disk = f.read() - if HAS_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_decompressed = dctx.decompress(quant_blob_disk) - else: - quant_decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval: gives each scored token (seq_len - stride) context. - if args.eval_stride > 0: - torch.cuda.synchronize() - t_sw = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding_window( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window_eval stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" - ) - log0( - f"final_sliding_window_eval_exact stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Fri Mar 20 02:59:26 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | -| N/A 34C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | -| N/A 30C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | -| N/A 29C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | -| N/A 34C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 30C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | -| N/A 30C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | -| N/A 35C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 63083 C /usr/local/bin/python 1510MiB | -| 1 N/A N/A 63084 C /usr/local/bin/python 1510MiB | -| 2 N/A N/A 63085 C /usr/local/bin/python 1510MiB | -| 3 N/A N/A 63086 C /usr/local/bin/python 1510MiB | -| 4 N/A N/A 63087 C /usr/local/bin/python 1510MiB | -| 5 N/A N/A 63088 C /usr/local/bin/python 1510MiB | -| 6 N/A N/A 63089 C /usr/local/bin/python 1510MiB | -| 7 N/A N/A 63090 C /usr/local/bin/python 1510MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:22368840 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/20000 train_loss:6.9360 train_time:54ms step_avg:54.31ms -step:2/20000 train_loss:11.9110 train_time:95ms step_avg:47.40ms -step:3/20000 train_loss:6.9870 train_time:150ms step_avg:49.93ms -step:4/20000 train_loss:6.5188 train_time:207ms step_avg:51.71ms -step:5/20000 train_loss:7.0146 train_time:256ms step_avg:51.19ms -step:6/20000 train_loss:7.6320 train_time:312ms step_avg:52.04ms -step:7/20000 train_loss:6.7726 train_time:366ms step_avg:52.29ms -step:8/20000 train_loss:6.4472 train_time:424ms step_avg:52.99ms -step:9/20000 train_loss:6.2283 train_time:476ms step_avg:52.87ms -step:10/20000 train_loss:6.0399 train_time:533ms step_avg:53.32ms -step:200/20000 train_loss:2.8733 train_time:9951ms step_avg:49.75ms -step:400/20000 train_loss:2.3103 train_time:19855ms step_avg:49.64ms -step:600/20000 train_loss:2.5058 train_time:29786ms step_avg:49.64ms -step:800/20000 train_loss:2.2640 train_time:39699ms step_avg:49.62ms -step:1000/20000 train_loss:2.3549 train_time:49633ms step_avg:49.63ms -step:1200/20000 train_loss:2.3742 train_time:59568ms step_avg:49.64ms -step:1400/20000 train_loss:2.4250 train_time:69496ms step_avg:49.64ms -step:1600/20000 train_loss:2.0981 train_time:79428ms step_avg:49.64ms -step:1800/20000 train_loss:2.1951 train_time:89342ms step_avg:49.63ms -step:2000/20000 train_loss:2.2342 train_time:99282ms step_avg:49.64ms -step:2200/20000 train_loss:2.0520 train_time:109218ms step_avg:49.64ms -step:2400/20000 train_loss:2.1752 train_time:119147ms step_avg:49.64ms -step:2600/20000 train_loss:2.3816 train_time:129101ms step_avg:49.65ms -step:2800/20000 train_loss:2.2040 train_time:139034ms step_avg:49.65ms -step:3000/20000 train_loss:2.1976 train_time:148962ms step_avg:49.65ms -step:3200/20000 train_loss:2.1568 train_time:158936ms step_avg:49.67ms -step:3400/20000 train_loss:2.1293 train_time:168883ms step_avg:49.67ms -step:3600/20000 train_loss:2.0805 train_time:178828ms step_avg:49.67ms -step:3800/20000 train_loss:2.1861 train_time:188749ms step_avg:49.67ms -step:4000/20000 train_loss:2.1296 train_time:198681ms step_avg:49.67ms -step:4200/20000 train_loss:2.1472 train_time:208700ms step_avg:49.69ms -step:4400/20000 train_loss:2.0760 train_time:218620ms step_avg:49.69ms -step:4600/20000 train_loss:1.9327 train_time:228570ms step_avg:49.69ms -step:4800/20000 train_loss:2.2263 train_time:238524ms step_avg:49.69ms -step:5000/20000 train_loss:1.9942 train_time:248446ms step_avg:49.69ms -step:5200/20000 train_loss:2.1393 train_time:258372ms step_avg:49.69ms -step:5400/20000 train_loss:2.1549 train_time:268304ms step_avg:49.69ms -step:5600/20000 train_loss:2.1469 train_time:278242ms step_avg:49.69ms -step:5800/20000 train_loss:2.1108 train_time:288216ms step_avg:49.69ms -step:6000/20000 train_loss:2.1936 train_time:298205ms step_avg:49.70ms -step:6200/20000 train_loss:2.0551 train_time:308176ms step_avg:49.71ms -step:6400/20000 train_loss:2.1305 train_time:318155ms step_avg:49.71ms -step:6600/20000 train_loss:2.0871 train_time:328123ms step_avg:49.72ms -step:6800/20000 train_loss:2.1626 train_time:338112ms step_avg:49.72ms -step:7000/20000 train_loss:2.1948 train_time:348088ms step_avg:49.73ms -step:7200/20000 train_loss:2.1656 train_time:358079ms step_avg:49.73ms -step:7400/20000 train_loss:2.0888 train_time:368049ms step_avg:49.74ms -step:7600/20000 train_loss:1.9626 train_time:378011ms step_avg:49.74ms -step:7800/20000 train_loss:2.1138 train_time:387989ms step_avg:49.74ms -step:8000/20000 train_loss:2.0847 train_time:397958ms step_avg:49.74ms -step:8200/20000 train_loss:2.1603 train_time:407935ms step_avg:49.75ms -step:8400/20000 train_loss:2.1027 train_time:418010ms step_avg:49.76ms -step:8600/20000 train_loss:2.1065 train_time:427980ms step_avg:49.77ms -step:8800/20000 train_loss:2.0714 train_time:437943ms step_avg:49.77ms -step:9000/20000 train_loss:1.9954 train_time:447930ms step_avg:49.77ms -step:9200/20000 train_loss:2.0574 train_time:458197ms step_avg:49.80ms -step:9400/20000 train_loss:2.0975 train_time:468123ms step_avg:49.80ms -step:9600/20000 train_loss:2.1099 train_time:478053ms step_avg:49.80ms -step:9800/20000 train_loss:2.0284 train_time:487968ms step_avg:49.79ms -step:10000/20000 train_loss:2.0685 train_time:497890ms step_avg:49.79ms -step:10200/20000 train_loss:2.0115 train_time:507802ms step_avg:49.78ms -step:10400/20000 train_loss:2.0397 train_time:517716ms step_avg:49.78ms -step:10600/20000 train_loss:1.9194 train_time:527650ms step_avg:49.78ms -step:10800/20000 train_loss:2.1216 train_time:537617ms step_avg:49.78ms -step:11000/20000 train_loss:2.0445 train_time:547588ms step_avg:49.78ms -step:11200/20000 train_loss:1.9996 train_time:557571ms step_avg:49.78ms -step:11400/20000 train_loss:1.9753 train_time:567547ms step_avg:49.78ms -step:11600/20000 train_loss:1.9778 train_time:577534ms step_avg:49.79ms -step:11800/20000 train_loss:2.0026 train_time:587535ms step_avg:49.79ms -step:12000/20000 train_loss:1.9772 train_time:597520ms step_avg:49.79ms -step:12047/20000 val_loss:2.0079 val_bpb:1.1892 train_time:599974ms step_avg:49.80ms -stopping_early: wallclock_cap train_time:599974ms step:12047/20000 -peak memory allocated: 11340 MiB reserved: 11504 MiB -Serialized model: 87413210 bytes -Code size: 61629 bytes -Total submission size: 87474839 bytes -Serialized model int6+zstd-22: 15817180 bytes (payload:23609632 raw_torch:23655445 payload_ratio:3.70x) -Total submission size int6+zstd-22: 15878809 bytes -final_int6_roundtrip val_loss:2.0077 val_bpb:1.1891 eval_time:1553ms -final_int6_roundtrip_exact val_loss:2.00771951 val_bpb:1.18908459 -final_sliding_window_eval stride:64 val_loss:1.9511 val_bpb:1.1556 eval_time:75096ms -final_sliding_window_eval_exact stride:64 val_loss:1.95110428 val_bpb:1.15555486 diff --git a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train_gpt_v5.py b/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train_gpt_v5.py deleted file mode 100644 index ed29085c15..0000000000 --- a/records/track_10min_16mb/2026-03-19_smeargate_orthoinit_muonwd/train_gpt_v5.py +++ /dev/null @@ -1,1422 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard as zstd - HAS_ZSTD = True -except ImportError: - HAS_ZSTD = False - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - # Sliding window eval: stride controls how many tokens to slide between windows. - # Only the last `eval_stride` tokens per window are scored, giving each scored token - # (seq_len - eval_stride) tokens of context. Set to 0 to disable (use standard eval). - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 3)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.030)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.020)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.020)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.01)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # BigramHash: inject token-pair context via a hash-table embedding. - bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 4096)) - bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) - - # SmearGate: blend each token's embedding with the previous token's - # via a tiny learned gate. Injects bigram context before the transformer. - use_smeargate = bool(int(os.environ.get("USE_SMEARGATE", "1"))) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group["weight_decay"] - - # Decoupled weight decay: shrink weights toward zero before the - # gradient update. This keeps the weight distribution tight and - # well-centred, which (a) improves generalisation and (b) makes - # post-training int6 quantisation compress better — the tighter - # distribution fits into fewer quantisation buckets with less error. - if wd > 0: - for p in params: - p.mul_(1.0 - wd * lr) - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -def eval_val_sliding_window( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int = 64, -) -> tuple[float, float]: - """Sliding window evaluation: each scored token gets (seq_len - stride) context. - - Instead of chopping validation into non-overlapping 1024-token blocks (where - the first token in each block gets zero context), we slide a 1024-token window - by `stride` tokens at a time and only score the last `stride` tokens per window. - Every scored token sees 960+ tokens of context, dramatically improving BPB. - """ - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 # need 1 extra for final target - - # Number of windows: first window at offset 0, then slide by stride - num_windows = max((total_tokens - seq_len) // stride + 1, 1) - - # Distribute windows across ranks - win_start = (num_windows * rank) // world_size - win_end = (num_windows * (rank + 1)) // world_size - - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Eval batch: how many windows per forward pass. Tune for memory. - eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) - - base_model.eval() - with torch.inference_mode(): - window_list = list(range(win_start, win_end)) - num_batches = (len(window_list) + eval_batch - 1) // eval_batch - for batch_idx in range(num_batches): - batch_wins = window_list[batch_idx * eval_batch : (batch_idx + 1) * eval_batch] - bsz = len(batch_wins) - - # Build input: each window is val_tokens[w*stride : w*stride + seq_len] - inputs = torch.stack([ - val_tokens[w * stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64) # [B, seq_len] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(inputs) # [B, seq_len, vocab] - - # Score only the last `stride` positions per window. - # logits[:, j, :] predicts the token AFTER position j in the input. - # So logits[:, -stride:, :] predicts tokens at input positions - # [seq_len - stride + 1, seq_len + 1) — which are the targets. - scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) # [B*stride, vocab] - - # Build targets: for window w, targets are val_tokens[w*stride + seq_len - stride + 1 : w*stride + seq_len + 1] - targets = torch.stack([ - val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") - val_loss_sum += loss.to(torch.float64) - val_token_count += float(targets.numel()) - - # BPB: prev_ids are the input tokens at each scored position - prev_ids = torch.stack([ - val_tokens[w * stride + seq_len - stride : w * stride + seq_len] - for w in batch_wins - ]).to(device=device, dtype=torch.int64).reshape(-1) # [B*stride] - - tgt_ids = targets - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if rank == 0 and (batch_idx + 1) % 100 == 0: - print(f" sw_eval batch {batch_idx + 1}/{num_batches}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - base_model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 - -INT6_QUANT_RANGE = 31 # int6: [-31, 31] -INT6_CLIP_Q = 0.9999984 - -# Tensors matching these patterns are stored as fp16 passthrough (not quantized). -# For weights without STE fake-quant protection (e.g. nn.Embedding, bigram hash table). -FP16_PASSTHROUGH_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "FP16_PASSTHROUGH_PATTERNS", - "tok_emb,bigram_hash", - ).split(",") - if pattern -) - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Int6 per-row quantization: [-31, 31] range stored in int8 containers. - # The unused high bits compress extremely well with zstd. - clip_abs = ( - torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / float(INT6_QUANT_RANGE)).clamp_min(1.0 / float(INT6_QUANT_RANGE)) - q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_QUANT_RANGE, INT6_QUANT_RANGE).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars still use int8 per-tensor scale (they're tiny). - clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - # fp16 passthrough for tensors without STE protection (tok_emb, bigram_hash). - if any(pattern in name for pattern in FP16_PASSTHROUGH_PATTERNS): - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - kept = t.to(dtype=torch.float16).contiguous() - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - # During training, apply fake int6 quantization (STE) so the model learns to be robust - # to the post-training quantization that will be applied for the 16MB artifact. - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if self.training and w.ndim == 2: - # Fake int6 per-row quantization with Straight-Through Estimator: - # Forward uses quantized weights, backward passes gradients through as-is. - with torch.no_grad(): - w32 = w.float() - clip_abs = torch.quantile(w32.abs(), INT6_CLIP_Q, dim=1).clamp_min(1e-8) - scale = clip_abs / INT6_QUANT_RANGE - w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) - w_q = (torch.round(w_clipped / scale[:, None]) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() # STE: value of w_q, gradient of w - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class BigramHash(nn.Module): - """Hash-table embedding for token bigrams, projected to model dim. - Maps (prev_token, cur_token) pairs via a simple hash to a learned embedding. - Gives the model cheap character-pair / bigram info before attention. - """ - def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): - super().__init__() - self.num_buckets = num_buckets - self.table = nn.Embedding(num_buckets, hash_dim) - self.proj = CastedLinear(hash_dim, model_dim, bias=False) - self.proj._zero_init = True - nn.init.normal_(self.table.weight, std=0.01) - - def forward(self, input_ids: Tensor) -> Tensor: - bsz, seqlen = input_ids.shape - prev_ids = torch.cat([torch.zeros(bsz, 1, dtype=input_ids.dtype, device=input_ids.device), input_ids[:, :-1]], dim=1) - h = ((prev_ids.long() * 92821 + input_ids.long()) % self.num_buckets).long() - return self.proj(self.table(h)) - - -class SmearGate(nn.Module): - """Learned per-dimension gate blending each token's embedding with the - previous token's. Injects bigram (two-token) context directly into the - embedding layer *before* the transformer starts processing. - - Normally a transformer must discover token-pair relationships through - self-attention; SmearGate provides this signal for free at ~dim params. - - Technique originated by @unnir in parameter-golf PR #102/#135. - """ - def __init__(self, dim: int): - super().__init__() - # Initialise so sigmoid(gate) ≈ 0.95 → mostly pass-through at init, - # with a small amount of previous-token blending that the model can - # learn to increase or decrease per dimension. - self.gate = nn.Parameter(torch.full((dim,), 3.0, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate).to(dtype=x.dtype) - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return g * x + (1.0 - g) * x_prev - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_hash_buckets: int = 0, - bigram_hash_dim: int = 128, - use_smeargate: bool = True, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram_hash = BigramHash(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash_buckets > 0 else None - self.smeargate = SmearGate(model_dim) if use_smeargate else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - else: - # Orthogonal init: produces well-conditioned weight matrices - # whose singular values are all 1. This gives the model a - # better starting point — gradients flow more uniformly - # through orthogonal matrices, so early training steps are - # more informative, which matters when you only get ~12k - # steps in the 10-minute budget. - nn.init.orthogonal_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram_hash is not None: - x = x + self.bigram_hash(input_ids) - if self.smeargate is not None: - x = self.smeargate(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits without computing loss. Used for sliding window eval.""" - x = self.tok_emb(input_ids) - if self.bigram_hash is not None: - x = x + self.bigram_hash(input_ids) - if self.smeargate is not None: - x = self.smeargate(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight.to(x.dtype)) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_hash_buckets=args.bigram_hash_buckets, - bigram_hash_dim=args.bigram_hash_dim, - use_smeargate=args.use_smeargate, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - # Collect embedding-like params - embed_params = [base_model.tok_emb.weight] - if base_model.bigram_hash is not None: - embed_params.append(base_model.bigram_hash.table.weight) - # bigram_hash.proj is a CastedLinear — its weight goes to Muon - matrix_params.append(base_model.bigram_hash.proj.weight) - optimizer_tok = torch.optim.Adam( - [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_weight_decay, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - # SmearGate gate is a 1D parameter → goes to scalar optimizer - if base_model.smeargate is not None: - scalar_params.append(base_model.smeargate.gate) - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - if HAS_ZSTD: - cctx = zstd.ZstdCompressor(level=22) - quant_blob = cctx.compress(quant_raw) - compress_name = "zstd-22" - else: - quant_blob = zlib.compress(quant_raw, level=9) - compress_name = "zlib-9" - quant_raw_bytes = len(quant_raw) - if master_process: - quant_ext = "int6.ptz" if HAS_ZSTD else "int6.ptz" - with open(f"final_model.{quant_ext}", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize(f"final_model.{quant_ext}") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int6+{compress_name}: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int6+{compress_name}: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - quant_ext = "int6.ptz" - with open(f"final_model.{quant_ext}", "rb") as f: - quant_blob_disk = f.read() - if HAS_ZSTD: - dctx = zstd.ZstdDecompressor() - quant_decompressed = dctx.decompress(quant_blob_disk) - else: - quant_decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(quant_decompressed), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval: gives each scored token (seq_len - stride) context. - if args.eval_stride > 0: - torch.cuda.synchronize() - t_sw = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding_window( - args, - base_model, - rank, - world_size, - device, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - stride=args.eval_stride, - ) - torch.cuda.synchronize() - log0( - f"final_sliding_window_eval stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" - ) - log0( - f"final_sliding_window_eval_exact stride:{args.eval_stride} " - f"val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}" - ) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/README.md b/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/README.md deleted file mode 100644 index d06e11fc62..0000000000 --- a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# 10L Int5-MLP + BigramHash(10240) + SWA(frac=0.4) + WD=0.04 - -**val_bpb: 1.14276** (mean of 3 seeds, sliding window stride=64, post int5/int6+zstd quantization roundtrip) - -## Run Command - -```bash -# Setup (once) -bash prepare.sh - -# Train + evaluate (default seed=42) -bash eval/eval.sh - -# With specific seed -SEED=42 bash eval/eval.sh -``` - -All parameters are set as defaults in `train_gpt.py`. No env vars needed. - -## 3-Seed Results - -| Seed | val_bpb | artifact_bytes | valid | -|------|---------|---------------|-------| -| 42 | 1.14271 | 15,965,978 | yes | -| 1337 | 1.14298 | 15,830,186 | yes | -| 2024 | 1.14260 | ~15.8M | yes | -| **Mean** | **1.14276** | | | -| **Std** | **0.00016** | | | - -## Key Techniques - -### Mixed Int5/Int6 Quantization -- **Int5 [-16,15]** for MLP weights (most compressible, 1.88x zstd ratio) -- **Int6 [-32,31]** for attention weights (precision-sensitive, 1.51x zstd ratio) -- **FP16** for tied embeddings and last-layer key projections -- Int5 MLP saves ~1.86MB vs uniform int6, funding a 10th layer - -### BigramHash(10240) -- Hash consecutive token pairs into 10240-bucket embedding table (dim=128) -- Projected to model_dim=512 via learned linear -- Reduces token-pair hash collisions vs 4096 buckets (+0.001 bpb) - -### SWA with start_frac=0.4 -- Collect checkpoints only from last 40% of warmdown (most converged) -- 24 checkpoints averaged every 50 steps -- Quality over quantity: fewer but better-converged checkpoints - -## Architecture -- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA) -- MLP 3x expansion (hidden=1536), relu^2 activation -- SmearGate + BigramHash(10240, dim=128) -- Orthogonal init with muP-scaled output projections -- U-Net skip connections, tied embeddings - -## Training Hyperparameters -- Muon optimizer: matrix_lr=0.02, WD=0.04, momentum=0.99 -- AdamW for embeddings/scalars: WD=0.04 -- warmdown=3000 iters, warmup=20 steps -- seq_len=2048, batch=786K tokens -- grad_clip=0.3, 3% magnitude pruning -- SWA: start_frac=0.4, every=50 steps -- Sliding window eval: stride=64 - -## Ablation Summary -| Change | val_bpb | Delta | -|--------|---------|-------| -| 9L int6 (PR162 base) | 1.1485 | baseline | -| + int5 MLP + 10th layer | 1.1453 | -0.003 | -| + WD=0.04 + warmdown=3000 | 1.1452 | -0.0001 | -| + SWA_start_frac=0.4 | 1.1446 | -0.0006 | -| + bigram=8192 | 1.1434 | -0.0012 | -| + bigram=10240 | **1.1426** | **-0.0008** | - -Built on PR #162 by @unnir (SmearGate, BigramHash, OrthoInit). diff --git a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/submission.json b/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/submission.json deleted file mode 100644 index ab55fded8b..0000000000 --- a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/submission.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "name": "10L Int5-MLP + BigramHash(10240) + SWA(frac=0.4) + WD=0.04", - "val_loss": 1.14276, - "bytes_total": 15900000, - "blurb": "10 layers with mixed int5/int6 quantization. BigramHash 10240 buckets (up from 4096). SWA start_frac=0.4 (24 converged checkpoints). WD=0.04 global, warmdown=3000. Mean of 3 seeds: 1.14276 (std 0.00016). SmearGate + OrthoInit + zstd-22.", - "author": "thwu1", - "github_id": "thwu1", - "date": "2026-03-20" -} diff --git a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py b/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py deleted file mode 100644 index bbe5ab2943..0000000000 --- a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py +++ /dev/null @@ -1,1231 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned diff --git a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed1337.log b/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed1337.log deleted file mode 100644 index 37e2794d8f..0000000000 --- a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed1337.log +++ /dev/null @@ -1,1425 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned - -==================================================================================================== -Running Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] -Running PyTorch 2.10.0+cu128 -Fri Mar 20 09:00:51 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | -| N/A 33C P0 119W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2D:00.0 Off | 0 | -| N/A 35C P0 122W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3F:00.0 Off | 0 | -| N/A 37C P0 123W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:66:00.0 Off | 0 | -| N/A 33C P0 120W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | -| N/A 33C P0 117W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AE:00.0 Off | 0 | -| N/A 37C P0 118W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BF:00.0 Off | 0 | -| N/A 36C P0 123W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | -| N/A 33C P0 117W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 1396149 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 1 N/A N/A 1396150 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 2 N/A N/A 1396151 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 3 N/A N/A 1396152 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 4 N/A N/A 1396153 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 5 N/A N/A 1396154 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 6 N/A N/A 1396155 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 7 N/A N/A 1396156 C ...hao/miniconda3/bin/python3.13 1516MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:25517137 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9301 train_time:130ms step_avg:130.32ms -step:2/20000 train_loss:7.9520 train_time:195ms step_avg:97.71ms -step:3/20000 train_loss:7.5214 train_time:282ms step_avg:94.02ms -step:4/20000 train_loss:6.9070 train_time:369ms step_avg:92.14ms -step:5/20000 train_loss:6.7675 train_time:455ms step_avg:91.07ms -step:6/20000 train_loss:6.7193 train_time:543ms step_avg:90.42ms -step:7/20000 train_loss:6.5964 train_time:630ms step_avg:89.93ms -step:8/20000 train_loss:6.4789 train_time:716ms step_avg:89.49ms -step:9/20000 train_loss:6.2072 train_time:803ms step_avg:89.27ms -step:10/20000 train_loss:5.9974 train_time:891ms step_avg:89.09ms -step:100/20000 train_loss:3.1676 train_time:8824ms step_avg:88.24ms -step:200/20000 train_loss:2.3773 train_time:17751ms step_avg:88.76ms -step:300/20000 train_loss:2.5384 train_time:26695ms step_avg:88.98ms -step:400/20000 train_loss:2.4039 train_time:35636ms step_avg:89.09ms -step:500/20000 train_loss:2.3881 train_time:44479ms step_avg:88.96ms -step:500/20000 val_loss:2.3494 val_bpb:1.3914 train_time:44503ms step_avg:89.01ms -step:600/20000 train_loss:2.3305 train_time:53414ms step_avg:89.02ms -step:700/20000 train_loss:2.3452 train_time:62364ms step_avg:89.09ms -step:800/20000 train_loss:2.2337 train_time:71316ms step_avg:89.15ms -step:900/20000 train_loss:2.1325 train_time:80251ms step_avg:89.17ms -step:1000/20000 train_loss:2.2755 train_time:89100ms step_avg:89.10ms -step:1000/20000 val_loss:2.2289 val_bpb:1.3201 train_time:89123ms step_avg:89.12ms -step:1100/20000 train_loss:2.3292 train_time:98045ms step_avg:89.13ms -step:1200/20000 train_loss:2.3543 train_time:107003ms step_avg:89.17ms -step:1300/20000 train_loss:2.1038 train_time:115960ms step_avg:89.20ms -step:1400/20000 train_loss:2.1846 train_time:124918ms step_avg:89.23ms -step:1500/20000 train_loss:2.2223 train_time:133766ms step_avg:89.18ms -step:1500/20000 val_loss:2.1862 val_bpb:1.2948 train_time:133791ms step_avg:89.19ms -step:1600/20000 train_loss:2.0782 train_time:142730ms step_avg:89.21ms -step:1700/20000 train_loss:2.1447 train_time:151687ms step_avg:89.23ms -step:1800/20000 train_loss:2.1678 train_time:160640ms step_avg:89.24ms -step:1900/20000 train_loss:2.1295 train_time:169506ms step_avg:89.21ms -step:2000/20000 train_loss:2.0681 train_time:178463ms step_avg:89.23ms -step:2000/20000 val_loss:2.1350 val_bpb:1.2645 train_time:178486ms step_avg:89.24ms -step:2100/20000 train_loss:2.0490 train_time:187415ms step_avg:89.25ms -step:2200/20000 train_loss:2.1476 train_time:196373ms step_avg:89.26ms -step:2300/20000 train_loss:2.1116 train_time:205332ms step_avg:89.27ms -step:2400/20000 train_loss:2.0671 train_time:214189ms step_avg:89.25ms -step:2500/20000 train_loss:2.1701 train_time:223144ms step_avg:89.26ms -step:2500/20000 val_loss:2.1082 val_bpb:1.2486 train_time:223169ms step_avg:89.27ms -step:2600/20000 train_loss:2.1107 train_time:232104ms step_avg:89.27ms -step:2700/20000 train_loss:2.1003 train_time:241053ms step_avg:89.28ms -step:2800/20000 train_loss:2.1585 train_time:250014ms step_avg:89.29ms -step:2900/20000 train_loss:2.0284 train_time:258873ms step_avg:89.27ms -step:3000/20000 train_loss:2.1618 train_time:267822ms step_avg:89.27ms -step:3000/20000 val_loss:2.0930 val_bpb:1.2396 train_time:267847ms step_avg:89.28ms -step:3100/20000 train_loss:2.0408 train_time:276775ms step_avg:89.28ms -step:3200/20000 train_loss:2.1751 train_time:285728ms step_avg:89.29ms -step:3300/20000 train_loss:2.0734 train_time:294590ms step_avg:89.27ms -step:3400/20000 train_loss:2.0257 train_time:303549ms step_avg:89.28ms -step:3500/20000 train_loss:2.1861 train_time:312503ms step_avg:89.29ms -step:3500/20000 val_loss:2.0858 val_bpb:1.2353 train_time:312527ms step_avg:89.29ms -step:3600/20000 train_loss:2.0995 train_time:321470ms step_avg:89.30ms -step:3700/20000 train_loss:2.1005 train_time:330435ms step_avg:89.31ms -step:3800/20000 train_loss:2.0772 train_time:339288ms step_avg:89.29ms -step:3900/20000 train_loss:2.0822 train_time:348240ms step_avg:89.29ms -step:4000/20000 train_loss:1.9797 train_time:357194ms step_avg:89.30ms -step:4000/20000 val_loss:2.0728 val_bpb:1.2276 train_time:357219ms step_avg:89.30ms -step:4100/20000 train_loss:2.0222 train_time:366145ms step_avg:89.30ms -step:4200/20000 train_loss:2.1538 train_time:375112ms step_avg:89.31ms -step:4300/20000 train_loss:2.0591 train_time:383969ms step_avg:89.30ms -step:4400/20000 train_loss:2.0388 train_time:392915ms step_avg:89.30ms -step:4500/20000 train_loss:2.1274 train_time:401871ms step_avg:89.30ms -step:4500/20000 val_loss:2.0489 val_bpb:1.2135 train_time:401896ms step_avg:89.31ms -step:4600/20000 train_loss:1.8466 train_time:410829ms step_avg:89.31ms -step:4700/20000 train_loss:2.2395 train_time:419693ms step_avg:89.30ms -step:4800/20000 train_loss:2.4328 train_time:428647ms step_avg:89.30ms -step:4900/20000 train_loss:2.0547 train_time:437593ms step_avg:89.30ms -step:5000/20000 train_loss:2.1083 train_time:446543ms step_avg:89.31ms -step:5000/20000 val_loss:2.0278 val_bpb:1.2010 train_time:446567ms step_avg:89.31ms -step:5100/20000 train_loss:2.1281 train_time:455499ms step_avg:89.31ms -step:5200/20000 train_loss:2.0457 train_time:464350ms step_avg:89.30ms -step:5300/20000 train_loss:2.0089 train_time:473303ms step_avg:89.30ms -step:5400/20000 train_loss:2.0496 train_time:482254ms step_avg:89.31ms -step:5500/20000 train_loss:2.0190 train_time:491211ms step_avg:89.31ms -step:5500/20000 val_loss:2.0048 val_bpb:1.1873 train_time:491235ms step_avg:89.32ms -swa:start step:5550 -step:5600/20000 train_loss:1.9565 train_time:500228ms step_avg:89.33ms -step:5700/20000 train_loss:2.0161 train_time:509157ms step_avg:89.33ms -step:5800/20000 train_loss:1.9981 train_time:518160ms step_avg:89.34ms -step:5900/20000 train_loss:1.9062 train_time:527154ms step_avg:89.35ms -step:6000/20000 train_loss:1.9428 train_time:536166ms step_avg:89.36ms -step:6000/20000 val_loss:1.9814 val_bpb:1.1735 train_time:536215ms step_avg:89.37ms -step:6100/20000 train_loss:1.9190 train_time:545074ms step_avg:89.36ms -step:6200/20000 train_loss:1.9497 train_time:554086ms step_avg:89.37ms -step:6300/20000 train_loss:1.9479 train_time:563098ms step_avg:89.38ms -step:6400/20000 train_loss:1.9982 train_time:572118ms step_avg:89.39ms -step:6500/20000 train_loss:2.0839 train_time:581126ms step_avg:89.40ms -step:6500/20000 val_loss:1.9542 val_bpb:1.1574 train_time:581188ms step_avg:89.41ms -step:6600/20000 train_loss:1.8412 train_time:590051ms step_avg:89.40ms -step:6700/20000 train_loss:1.9494 train_time:599052ms step_avg:89.41ms -step:6711/20000 val_loss:1.9473 val_bpb:1.1533 train_time:600083ms step_avg:89.42ms -stopping_early: wallclock_cap train_time:600083ms step:6711/20000 -peak memory allocated: 18871 MiB reserved: 18976 MiB -swa:applying averaged 24 checkpoints -Serialized model: 98437419 bytes -Code size: 52930 bytes -Total submission size: 98490349 bytes -Serialized model int6+zstd: 15777256 bytes -Total submission size int8+zlib: 15830186 bytes -final_eval_mode:sliding_window stride:64 batch_seqs:32 -final_int8_zlib_roundtrip val_loss:1.9299 val_bpb:1.1430 eval_time:168222ms -final_int8_zlib_roundtrip_exact val_loss:1.92987577 val_bpb:1.14298416 diff --git a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed2024.log b/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed2024.log deleted file mode 100644 index 71d68549d3..0000000000 --- a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed2024.log +++ /dev/null @@ -1,1425 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned - -==================================================================================================== -Running Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] -Running PyTorch 2.10.0+cu128 -Fri Mar 20 09:15:32 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2D:00.0 Off | 0 | -| N/A 37C P0 123W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3F:00.0 Off | 0 | -| N/A 39C P0 124W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:66:00.0 Off | 0 | -| N/A 33C P0 120W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | -| N/A 34C P0 116W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AE:00.0 Off | 0 | -| N/A 39C P0 119W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BF:00.0 Off | 0 | -| N/A 37C P0 124W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | -| N/A 33C P0 118W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 1401433 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 1 N/A N/A 1401434 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 2 N/A N/A 1401435 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 3 N/A N/A 1401436 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 4 N/A N/A 1401437 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 5 N/A N/A 1401438 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 6 N/A N/A 1401439 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 7 N/A N/A 1401440 C ...hao/miniconda3/bin/python3.13 1516MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:25517137 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:2024 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9327 val_bpb:4.1059 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9341 train_time:131ms step_avg:130.80ms -step:2/20000 train_loss:8.1356 train_time:196ms step_avg:97.75ms -step:3/20000 train_loss:7.6976 train_time:282ms step_avg:93.95ms -step:4/20000 train_loss:6.9821 train_time:369ms step_avg:92.17ms -step:5/20000 train_loss:6.7578 train_time:457ms step_avg:91.30ms -step:6/20000 train_loss:6.6203 train_time:544ms step_avg:90.69ms -step:7/20000 train_loss:6.5157 train_time:632ms step_avg:90.25ms -step:8/20000 train_loss:6.5424 train_time:719ms step_avg:89.93ms -step:9/20000 train_loss:6.3044 train_time:807ms step_avg:89.67ms -step:10/20000 train_loss:6.0760 train_time:894ms step_avg:89.40ms -step:100/20000 train_loss:3.1623 train_time:8828ms step_avg:88.28ms -step:200/20000 train_loss:2.3732 train_time:17774ms step_avg:88.87ms -step:300/20000 train_loss:2.5410 train_time:26709ms step_avg:89.03ms -step:400/20000 train_loss:2.4028 train_time:35647ms step_avg:89.12ms -step:500/20000 train_loss:2.3913 train_time:44504ms step_avg:89.01ms -step:500/20000 val_loss:2.3539 val_bpb:1.3941 train_time:44527ms step_avg:89.05ms -step:600/20000 train_loss:2.3287 train_time:53442ms step_avg:89.07ms -step:700/20000 train_loss:2.3465 train_time:62396ms step_avg:89.14ms -step:800/20000 train_loss:2.2397 train_time:71351ms step_avg:89.19ms -step:900/20000 train_loss:2.1275 train_time:80303ms step_avg:89.23ms -step:1000/20000 train_loss:2.2699 train_time:89166ms step_avg:89.17ms -step:1000/20000 val_loss:2.2227 val_bpb:1.3164 train_time:89190ms step_avg:89.19ms -step:1100/20000 train_loss:2.3252 train_time:98110ms step_avg:89.19ms -step:1200/20000 train_loss:2.3519 train_time:107059ms step_avg:89.22ms -step:1300/20000 train_loss:2.0998 train_time:116014ms step_avg:89.24ms -step:1400/20000 train_loss:2.1813 train_time:124974ms step_avg:89.27ms -step:1500/20000 train_loss:2.2204 train_time:133836ms step_avg:89.22ms -step:1500/20000 val_loss:2.1828 val_bpb:1.2928 train_time:133860ms step_avg:89.24ms -step:1600/20000 train_loss:2.0735 train_time:142781ms step_avg:89.24ms -step:1700/20000 train_loss:2.1410 train_time:151731ms step_avg:89.25ms -step:1800/20000 train_loss:2.1571 train_time:160673ms step_avg:89.26ms -step:1900/20000 train_loss:2.1285 train_time:169532ms step_avg:89.23ms -step:2000/20000 train_loss:2.0697 train_time:178491ms step_avg:89.25ms -step:2000/20000 val_loss:2.1314 val_bpb:1.2623 train_time:178516ms step_avg:89.26ms -step:2100/20000 train_loss:2.0446 train_time:187440ms step_avg:89.26ms -step:2200/20000 train_loss:2.1392 train_time:196387ms step_avg:89.27ms -step:2300/20000 train_loss:2.1059 train_time:205340ms step_avg:89.28ms -step:2400/20000 train_loss:2.0699 train_time:214193ms step_avg:89.25ms -step:2500/20000 train_loss:2.1708 train_time:223135ms step_avg:89.25ms -step:2500/20000 val_loss:2.1068 val_bpb:1.2477 train_time:223159ms step_avg:89.26ms -step:2600/20000 train_loss:2.1094 train_time:232089ms step_avg:89.26ms -step:2700/20000 train_loss:2.1006 train_time:241051ms step_avg:89.28ms -step:2800/20000 train_loss:2.1570 train_time:249997ms step_avg:89.28ms -step:2900/20000 train_loss:2.0252 train_time:258847ms step_avg:89.26ms -step:3000/20000 train_loss:2.1603 train_time:267804ms step_avg:89.27ms -step:3000/20000 val_loss:2.0913 val_bpb:1.2386 train_time:267827ms step_avg:89.28ms -step:3100/20000 train_loss:2.0372 train_time:276764ms step_avg:89.28ms -step:3200/20000 train_loss:2.1747 train_time:285731ms step_avg:89.29ms -step:3300/20000 train_loss:2.0711 train_time:294595ms step_avg:89.27ms -step:3400/20000 train_loss:2.0218 train_time:303546ms step_avg:89.28ms -step:3500/20000 train_loss:2.1841 train_time:312483ms step_avg:89.28ms -step:3500/20000 val_loss:2.0835 val_bpb:1.2339 train_time:312508ms step_avg:89.29ms -step:3600/20000 train_loss:2.1019 train_time:321425ms step_avg:89.28ms -step:3700/20000 train_loss:2.0994 train_time:330379ms step_avg:89.29ms -step:3800/20000 train_loss:2.0746 train_time:339237ms step_avg:89.27ms -step:3900/20000 train_loss:2.0802 train_time:348187ms step_avg:89.28ms -step:4000/20000 train_loss:1.9782 train_time:357125ms step_avg:89.28ms -step:4000/20000 val_loss:2.0700 val_bpb:1.2260 train_time:357150ms step_avg:89.29ms -step:4100/20000 train_loss:2.0179 train_time:366073ms step_avg:89.29ms -step:4200/20000 train_loss:2.1571 train_time:375031ms step_avg:89.29ms -step:4300/20000 train_loss:2.0620 train_time:383875ms step_avg:89.27ms -step:4400/20000 train_loss:2.0357 train_time:392817ms step_avg:89.28ms -step:4500/20000 train_loss:2.1262 train_time:401756ms step_avg:89.28ms -step:4500/20000 val_loss:2.0470 val_bpb:1.2124 train_time:401780ms step_avg:89.28ms -step:4600/20000 train_loss:1.8455 train_time:410714ms step_avg:89.29ms -step:4700/20000 train_loss:2.2360 train_time:419566ms step_avg:89.27ms -step:4800/20000 train_loss:2.4315 train_time:428509ms step_avg:89.27ms -step:4900/20000 train_loss:2.0532 train_time:437462ms step_avg:89.28ms -step:5000/20000 train_loss:2.1049 train_time:446418ms step_avg:89.28ms -step:5000/20000 val_loss:2.0260 val_bpb:1.1999 train_time:446442ms step_avg:89.29ms -step:5100/20000 train_loss:2.1250 train_time:455362ms step_avg:89.29ms -step:5200/20000 train_loss:2.0462 train_time:464208ms step_avg:89.27ms -step:5300/20000 train_loss:2.0113 train_time:473152ms step_avg:89.27ms -step:5400/20000 train_loss:2.0500 train_time:482096ms step_avg:89.28ms -step:5500/20000 train_loss:2.0204 train_time:491045ms step_avg:89.28ms -step:5500/20000 val_loss:2.0035 val_bpb:1.1866 train_time:491069ms step_avg:89.29ms -swa:start step:5550 -step:5600/20000 train_loss:1.9568 train_time:500055ms step_avg:89.30ms -step:5700/20000 train_loss:2.0114 train_time:508967ms step_avg:89.29ms -step:5800/20000 train_loss:1.9999 train_time:517965ms step_avg:89.30ms -step:5900/20000 train_loss:1.9044 train_time:526979ms step_avg:89.32ms -step:6000/20000 train_loss:1.9416 train_time:535998ms step_avg:89.33ms -step:6000/20000 val_loss:1.9801 val_bpb:1.1727 train_time:536047ms step_avg:89.34ms -step:6100/20000 train_loss:1.9201 train_time:544896ms step_avg:89.33ms -step:6200/20000 train_loss:1.9478 train_time:553899ms step_avg:89.34ms -step:6300/20000 train_loss:1.9485 train_time:562912ms step_avg:89.35ms -step:6400/20000 train_loss:1.9963 train_time:571921ms step_avg:89.36ms -step:6500/20000 train_loss:2.0821 train_time:580925ms step_avg:89.37ms -step:6500/20000 val_loss:1.9528 val_bpb:1.1565 train_time:580987ms step_avg:89.38ms -step:6600/20000 train_loss:1.8417 train_time:589834ms step_avg:89.37ms -step:6700/20000 train_loss:1.9450 train_time:598834ms step_avg:89.38ms -step:6713/20000 val_loss:1.9458 val_bpb:1.1524 train_time:600034ms step_avg:89.38ms -stopping_early: wallclock_cap train_time:600034ms step:6713/20000 -peak memory allocated: 18871 MiB reserved: 18976 MiB -swa:applying averaged 24 checkpoints -Serialized model: 98437419 bytes -Code size: 52930 bytes -Total submission size: 98490349 bytes -Serialized model int6+zstd: 15631946 bytes -Total submission size int8+zlib: 15684876 bytes -final_eval_mode:sliding_window stride:64 batch_seqs:32 -final_int8_zlib_roundtrip val_loss:1.9292 val_bpb:1.1426 eval_time:168368ms -final_int8_zlib_roundtrip_exact val_loss:1.92923337 val_bpb:1.14260369 diff --git a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed42.log b/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed42.log deleted file mode 100644 index 4af55b7710..0000000000 --- a/records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_seed42.log +++ /dev/null @@ -1,1425 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 42)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 10)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if "bigram" in name: - return "bigram" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 8192: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention - q, s = quantize_intN_per_row(t, clip_range=clip) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # Magnitude pruning: zero out smallest weights to improve compression - with torch.no_grad(): - for name, param in base_model.named_parameters(): - if param.ndim == 2 and param.numel() > 65536: - threshold = torch.quantile(param.abs().float().flatten(), 0.03) - mask = param.abs() < threshold - param.masked_fill_(mask, 0.0) - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() -# fixes applied -# tuned - -==================================================================================================== -Running Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] -Running PyTorch 2.10.0+cu128 -Fri Mar 20 08:45:54 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | -| N/A 32C P0 117W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2D:00.0 Off | 0 | -| N/A 33C P0 122W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3F:00.0 Off | 0 | -| N/A 35C P0 122W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:66:00.0 Off | 0 | -| N/A 32C P0 120W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | -| N/A 32C P0 115W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AE:00.0 Off | 0 | -| N/A 35C P0 116W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BF:00.0 Off | 0 | -| N/A 34C P0 121W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | -| N/A 32C P0 115W / 700W | 1525MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 1390710 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 1 N/A N/A 1390711 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 2 N/A N/A 1390712 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 3 N/A N/A 1390713 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 4 N/A N/A 1390714 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 5 N/A N/A 1390715 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 6 N/A N/A 1390716 C ...hao/miniconda3/bin/python3.13 1516MiB | -| 7 N/A N/A 1390717 C ...hao/miniconda3/bin/python3.13 1516MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:25517137 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9323 val_bpb:4.1057 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9334 train_time:131ms step_avg:130.74ms -step:2/20000 train_loss:8.1444 train_time:196ms step_avg:98.20ms -step:3/20000 train_loss:7.6923 train_time:283ms step_avg:94.43ms -step:4/20000 train_loss:6.9898 train_time:371ms step_avg:92.74ms -step:5/20000 train_loss:6.8344 train_time:458ms step_avg:91.60ms -step:6/20000 train_loss:6.6323 train_time:545ms step_avg:90.81ms -step:7/20000 train_loss:6.5361 train_time:632ms step_avg:90.28ms -step:8/20000 train_loss:6.5762 train_time:719ms step_avg:89.89ms -step:9/20000 train_loss:6.3266 train_time:806ms step_avg:89.60ms -step:10/20000 train_loss:6.0560 train_time:894ms step_avg:89.39ms -step:100/20000 train_loss:3.1532 train_time:8816ms step_avg:88.16ms -step:200/20000 train_loss:2.3835 train_time:17744ms step_avg:88.72ms -step:300/20000 train_loss:2.5419 train_time:26679ms step_avg:88.93ms -step:400/20000 train_loss:2.4083 train_time:35610ms step_avg:89.03ms -step:500/20000 train_loss:2.3940 train_time:44456ms step_avg:88.91ms -step:500/20000 val_loss:2.3519 val_bpb:1.3929 train_time:44480ms step_avg:88.96ms -step:600/20000 train_loss:2.3335 train_time:53394ms step_avg:88.99ms -step:700/20000 train_loss:2.3419 train_time:62352ms step_avg:89.07ms -step:800/20000 train_loss:2.2378 train_time:71309ms step_avg:89.14ms -step:900/20000 train_loss:2.1268 train_time:80262ms step_avg:89.18ms -step:1000/20000 train_loss:2.2737 train_time:89118ms step_avg:89.12ms -step:1000/20000 val_loss:2.2258 val_bpb:1.3182 train_time:89142ms step_avg:89.14ms -step:1100/20000 train_loss:2.3197 train_time:98066ms step_avg:89.15ms -step:1200/20000 train_loss:2.3530 train_time:107021ms step_avg:89.18ms -step:1300/20000 train_loss:2.1035 train_time:115963ms step_avg:89.20ms -step:1400/20000 train_loss:2.1848 train_time:124918ms step_avg:89.23ms -step:1500/20000 train_loss:2.2191 train_time:133774ms step_avg:89.18ms -step:1500/20000 val_loss:2.1843 val_bpb:1.2937 train_time:133798ms step_avg:89.20ms -step:1600/20000 train_loss:2.0748 train_time:142729ms step_avg:89.21ms -step:1700/20000 train_loss:2.1419 train_time:151678ms step_avg:89.22ms -step:1800/20000 train_loss:2.1577 train_time:160626ms step_avg:89.24ms -step:1900/20000 train_loss:2.1327 train_time:169487ms step_avg:89.20ms -step:2000/20000 train_loss:2.0683 train_time:178437ms step_avg:89.22ms -step:2000/20000 val_loss:2.1331 val_bpb:1.2633 train_time:178461ms step_avg:89.23ms -step:2100/20000 train_loss:2.0465 train_time:187401ms step_avg:89.24ms -step:2200/20000 train_loss:2.1484 train_time:196360ms step_avg:89.25ms -step:2300/20000 train_loss:2.1132 train_time:205309ms step_avg:89.26ms -step:2400/20000 train_loss:2.0655 train_time:214159ms step_avg:89.23ms -step:2500/20000 train_loss:2.1696 train_time:223100ms step_avg:89.24ms -step:2500/20000 val_loss:2.1066 val_bpb:1.2476 train_time:223125ms step_avg:89.25ms -step:2600/20000 train_loss:2.1094 train_time:232061ms step_avg:89.25ms -step:2700/20000 train_loss:2.1006 train_time:241013ms step_avg:89.26ms -step:2800/20000 train_loss:2.1546 train_time:249980ms step_avg:89.28ms -step:2900/20000 train_loss:2.0250 train_time:258820ms step_avg:89.25ms -step:3000/20000 train_loss:2.1585 train_time:267773ms step_avg:89.26ms -step:3000/20000 val_loss:2.0922 val_bpb:1.2391 train_time:267797ms step_avg:89.27ms -step:3100/20000 train_loss:2.0360 train_time:276725ms step_avg:89.27ms -step:3200/20000 train_loss:2.1736 train_time:285688ms step_avg:89.28ms -step:3300/20000 train_loss:2.0739 train_time:294555ms step_avg:89.26ms -step:3400/20000 train_loss:2.0214 train_time:303521ms step_avg:89.27ms -step:3500/20000 train_loss:2.1841 train_time:312478ms step_avg:89.28ms -step:3500/20000 val_loss:2.0852 val_bpb:1.2349 train_time:312503ms step_avg:89.29ms -step:3600/20000 train_loss:2.1011 train_time:321439ms step_avg:89.29ms -step:3700/20000 train_loss:2.1038 train_time:330396ms step_avg:89.30ms -step:3800/20000 train_loss:2.0757 train_time:339245ms step_avg:89.27ms -step:3900/20000 train_loss:2.0864 train_time:348188ms step_avg:89.28ms -step:4000/20000 train_loss:1.9805 train_time:357142ms step_avg:89.29ms -step:4000/20000 val_loss:2.0722 val_bpb:1.2273 train_time:357167ms step_avg:89.29ms -step:4100/20000 train_loss:2.0233 train_time:366103ms step_avg:89.29ms -step:4200/20000 train_loss:2.1616 train_time:375075ms step_avg:89.30ms -step:4300/20000 train_loss:2.0600 train_time:383932ms step_avg:89.29ms -step:4400/20000 train_loss:2.0379 train_time:392892ms step_avg:89.29ms -step:4500/20000 train_loss:2.1254 train_time:401865ms step_avg:89.30ms -step:4500/20000 val_loss:2.0482 val_bpb:1.2130 train_time:401888ms step_avg:89.31ms -step:4600/20000 train_loss:1.8452 train_time:410814ms step_avg:89.31ms -step:4700/20000 train_loss:2.2389 train_time:419669ms step_avg:89.29ms -step:4800/20000 train_loss:2.4299 train_time:428641ms step_avg:89.30ms -step:4900/20000 train_loss:2.0543 train_time:437590ms step_avg:89.30ms -step:5000/20000 train_loss:2.1069 train_time:446548ms step_avg:89.31ms -step:5000/20000 val_loss:2.0272 val_bpb:1.2006 train_time:446571ms step_avg:89.31ms -step:5100/20000 train_loss:2.1283 train_time:455504ms step_avg:89.31ms -step:5200/20000 train_loss:2.0444 train_time:464368ms step_avg:89.30ms -step:5300/20000 train_loss:2.0105 train_time:473322ms step_avg:89.31ms -step:5400/20000 train_loss:2.0525 train_time:482286ms step_avg:89.31ms -step:5500/20000 train_loss:2.0201 train_time:491236ms step_avg:89.32ms -step:5500/20000 val_loss:2.0048 val_bpb:1.1874 train_time:491260ms step_avg:89.32ms -swa:start step:5550 -step:5600/20000 train_loss:1.9569 train_time:500257ms step_avg:89.33ms -step:5700/20000 train_loss:2.0141 train_time:509183ms step_avg:89.33ms -step:5800/20000 train_loss:1.9993 train_time:518202ms step_avg:89.35ms -step:5900/20000 train_loss:1.9068 train_time:527231ms step_avg:89.36ms -step:6000/20000 train_loss:1.9433 train_time:536240ms step_avg:89.37ms -step:6000/20000 val_loss:1.9814 val_bpb:1.1735 train_time:536303ms step_avg:89.38ms -step:6100/20000 train_loss:1.9155 train_time:545172ms step_avg:89.37ms -step:6200/20000 train_loss:1.9505 train_time:554198ms step_avg:89.39ms -step:6300/20000 train_loss:1.9472 train_time:563218ms step_avg:89.40ms -step:6400/20000 train_loss:2.0003 train_time:572221ms step_avg:89.41ms -step:6500/20000 train_loss:2.0806 train_time:581236ms step_avg:89.42ms -step:6500/20000 val_loss:1.9541 val_bpb:1.1573 train_time:581299ms step_avg:89.43ms -step:6600/20000 train_loss:1.8448 train_time:590169ms step_avg:89.42ms -step:6700/20000 train_loss:1.9475 train_time:599199ms step_avg:89.43ms -step:6709/20000 val_loss:1.9473 val_bpb:1.1533 train_time:600037ms step_avg:89.44ms -stopping_early: wallclock_cap train_time:600037ms step:6709/20000 -peak memory allocated: 18871 MiB reserved: 18976 MiB -swa:applying averaged 24 checkpoints -Serialized model: 98437419 bytes -Code size: 52930 bytes -Total submission size: 98490349 bytes -Serialized model int6+zstd: 15913048 bytes -Total submission size int8+zlib: 15965978 bytes -final_eval_mode:sliding_window stride:64 batch_seqs:32 -final_int8_zlib_roundtrip val_loss:1.9294 val_bpb:1.1427 eval_time:168421ms -final_int8_zlib_roundtrip_exact val_loss:1.92941613 val_bpb:1.14271194 diff --git a/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/README.md b/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/README.md deleted file mode 100644 index 22152b3936..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/README.md +++ /dev/null @@ -1,79 +0,0 @@ -# 11L + Efficient Partial XSA (val_bpb: 1.1307) - -## Results -- **val_bpb: 1.1307** (sliding window, stride=64) -- Pre-quantization BPB: 1.1437 -- Model parameters: 26,829,913 -- Artifact size: 15,892,986 bytes (under 16MB limit) -- Training: 6,976 steps in 600 seconds (~86ms/step) -- SWA: 13 checkpoint average during warmdown (every 120 steps) - -## Novel Contribution: Efficient Partial Exclusive Self Attention (XSA) - -Based on Exclusive Self Attention (arXiv:2603.09078), we introduce two key improvements: - -### 1. Efficient GQA-Aware Implementation -Standard XSA with Grouped Query Attention requires `repeat_interleave` to expand value vectors -from `num_kv_heads` to `num_heads`, doubling memory allocation per layer. Our implementation -uses a free reshape into KV head groups + broadcasting: - -```python -# OLD: expensive tensor duplication -v_expanded = v.repeat_interleave(group_size, dim=-2) # allocates 2x memory -vn = normalize(v_expanded) -y = y - dot(y, vn) * vn - -# NEW: free reshape + broadcast (zero allocation) -y_grouped = y.reshape(B, T, Hkv, group_size, D) # view, no copy -vn = normalize(v).unsqueeze(-2) # [B,T,Hkv,1,D] -y = (y_grouped - dot(y_grouped, vn) * vn).reshape(B, T, H, D) -``` - -This reduces XSA overhead from ~7ms/step to ~2ms/step at 11 layers with GQA (8 heads, 4 KV heads). - -### 2. Partial Application to Deepest Layers Only -The XSA paper shows self-attention bias (cosine similarity between output and self-value) -increases across layers. We apply XSA only to the **last 3 layers** (out of 11), targeting -the layers with highest self-attention bias while minimizing compute overhead. - -Combined, these give ~0.002 BPB improvement over the baseline at <2ms/step cost. - -## Architecture -- 11 transformer layers, 512-dim, 8 heads (4 KV heads via GQA) -- 3x MLP expansion (1536 hidden), relu-squared activation -- U-Net skip connections (encoder=5, decoder=6) -- SmearGate + BigramHash (2048 buckets, dim=128) -- Tied embeddings, logit softcap=30.0 -- NTK-aware RoPE (train_seq_len=1024, auto-scales at 2048) -- **XSA on layers 8, 9, 10** (deepest 3 of 11) - -## Training -- FlashAttention 3 (Hopper-optimized) -- Muon optimizer: lr=0.025, momentum=0.99 (warmup from 0.92 over 1500 steps) -- AdamW for embeddings/scalars: lr=0.035/0.025 -- Weight decay: 0.04 (both Muon and AdamW) -- Warmdown: 3000 iterations, grad clip 0.3 -- SWA every 120 steps (scale < 0.5), 13 checkpoint uniform average -- OrthoInit + muP-scaled output projections -- Seed: 1337 - -## Quantization -- Int6 per-row quantization on MLP + attention weights -- Int8 for embeddings -- zstd level 22 compression - -## Run Command -```bash -NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=2048 \ -MUON_WD=0.04 ADAM_WD=0.04 \ -MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ -MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \ -MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3000 \ -ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \ -SWA_EVERY=120 SWA_ENABLED=1 MTP_NUM_HEADS=0 SEED=1337 \ -WARMUP_STEPS=30 VAL_LOSS_EVERY=2000 XSA_LAST_N=3 \ -torchrun --nproc_per_node=8 train_gpt.py -``` - -## References -- Exclusive Self Attention: arXiv:2603.09078 (Shuangfei Zhai, 2026) diff --git a/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/submission.json b/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/submission.json deleted file mode 100644 index 69531e833d..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/submission.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "author": "vadim borisov (tabularis.ai)", - "github_id": "unnir", - "name": "11L + Efficient Partial XSA + FA3 + SWA/120 (val_bpb: 1.1307)", - "blurb": "11 layers, int6 quant, zstd-22. Novel contribution: Efficient Partial Exclusive Self Attention (XSA, arXiv:2603.09078) applied to deepest 3 layers only. GQA-aware reshape avoids tensor duplication, adding <2ms/step overhead. XSA subtracts self-value projection from attention output, forcing deeper layers to learn from context rather than self-reference. SWA every 120 steps (13 checkpoint avg). OrthoInit + muP scaling. SmearGate + BigramHash(2048x128). FlashAttention 3 + NTK RoPE. Weight decay 0.04 (Muon+AdamW).", - "date": "2026-03-20T20:15:00Z", - "val_loss": 1.90915845, - "val_bpb": 1.13071416, - "pre_quant_val_loss": 1.9311, - "pre_quant_val_bpb": 1.1437, - "int6_zstd_val_loss": 1.90915845, - "int6_zstd_val_bpb": 1.13071416, - "bytes_total": 15892986, - "bytes_model_int6_zstd": 15827986, - "bytes_code": 65000, - "seed": 1337 -} diff --git a/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train.log b/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train.log deleted file mode 100644 index e110fca22b..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train.log +++ /dev/null @@ -1,1682 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - # Efficient partial XSA: apply to last N layers only (deep layers have highest self-attention bias) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - # Reshape y into KV head groups — free view, no memory alloc - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready - # Project out self-value component per KV head group - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - # XSA: subtract self-value projection (deep layers only) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - # Enable efficient XSA on the deepest layers (highest self-attention bias) - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, # must match training model - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Fri Mar 20 20:11:49 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 37C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 32C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | -| N/A 33C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | -| N/A 56C P0 141W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | -| N/A 36C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | -| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | -| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | -| N/A 36C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 229341 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 229342 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 229343 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 229344 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 229345 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 229346 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 229347 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 229348 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_3 active_layers:[8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:30 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:10/30 -warmup_step:20/30 -warmup_step:30/30 -step:0/9000 val_loss:6.9303 val_bpb:4.1045 train_time:0ms step_avg:0.03ms -step:1/9000 train_loss:6.9326 train_time:126ms step_avg:126.40ms -step:2/9000 train_loss:8.5996 train_time:192ms step_avg:95.78ms -step:3/9000 train_loss:7.9279 train_time:272ms step_avg:90.65ms -step:4/9000 train_loss:7.2203 train_time:353ms step_avg:88.15ms -step:5/9000 train_loss:6.9910 train_time:433ms step_avg:86.65ms -step:6/9000 train_loss:6.8674 train_time:514ms step_avg:85.67ms -step:7/9000 train_loss:6.8833 train_time:595ms step_avg:85.06ms -step:8/9000 train_loss:6.8975 train_time:677ms step_avg:84.57ms -step:9/9000 train_loss:6.4748 train_time:757ms step_avg:84.15ms -step:10/9000 train_loss:6.1165 train_time:839ms step_avg:83.90ms -step:200/9000 train_loss:2.4053 train_time:16761ms step_avg:83.81ms -step:400/9000 train_loss:2.4085 train_time:33652ms step_avg:84.13ms -step:600/9000 train_loss:2.3292 train_time:50616ms step_avg:84.36ms -step:800/9000 train_loss:2.2284 train_time:67709ms step_avg:84.64ms -step:1000/9000 train_loss:2.2674 train_time:84768ms step_avg:84.77ms -step:1200/9000 train_loss:2.3457 train_time:101972ms step_avg:84.98ms -step:1400/9000 train_loss:2.1750 train_time:119163ms step_avg:85.12ms -step:1600/9000 train_loss:2.0696 train_time:136286ms step_avg:85.18ms -step:1800/9000 train_loss:2.1577 train_time:153518ms step_avg:85.29ms -step:2000/9000 train_loss:2.0630 train_time:170649ms step_avg:85.32ms -step:2000/9000 val_loss:2.1268 val_bpb:1.2596 train_time:170666ms step_avg:85.33ms -step:2200/9000 train_loss:2.1336 train_time:187829ms step_avg:85.38ms -step:2400/9000 train_loss:2.0606 train_time:205001ms step_avg:85.42ms -step:2600/9000 train_loss:2.1047 train_time:222237ms step_avg:85.48ms -step:2800/9000 train_loss:2.1514 train_time:239532ms step_avg:85.55ms -step:3000/9000 train_loss:2.1596 train_time:256746ms step_avg:85.58ms -step:3200/9000 train_loss:2.1729 train_time:273986ms step_avg:85.62ms -step:3400/9000 train_loss:2.0217 train_time:291191ms step_avg:85.64ms -step:3600/9000 train_loss:2.0941 train_time:308459ms step_avg:85.68ms -step:3800/9000 train_loss:2.0762 train_time:325650ms step_avg:85.70ms -step:4000/9000 train_loss:1.9846 train_time:342912ms step_avg:85.73ms -step:4000/9000 val_loss:2.0752 val_bpb:1.2291 train_time:342932ms step_avg:85.73ms -step:4200/9000 train_loss:2.1623 train_time:360201ms step_avg:85.76ms -step:4400/9000 train_loss:2.0410 train_time:377440ms step_avg:85.78ms -step:4600/9000 train_loss:1.8509 train_time:394678ms step_avg:85.80ms -step:4800/9000 train_loss:2.4358 train_time:411907ms step_avg:85.81ms -step:5000/9000 train_loss:2.1116 train_time:429219ms step_avg:85.84ms -step:5200/9000 train_loss:2.0480 train_time:446432ms step_avg:85.85ms -step:5400/9000 train_loss:2.0537 train_time:463738ms step_avg:85.88ms -swa:start step:5520 -step:5600/9000 train_loss:1.9593 train_time:481071ms step_avg:85.91ms -step:5800/9000 train_loss:2.0039 train_time:498364ms step_avg:85.92ms -step:6000/9000 train_loss:1.9465 train_time:515670ms step_avg:85.95ms -step:6000/9000 val_loss:1.9851 val_bpb:1.1757 train_time:515715ms step_avg:85.95ms -step:6200/9000 train_loss:1.9572 train_time:532916ms step_avg:85.95ms -step:6400/9000 train_loss:2.0003 train_time:550278ms step_avg:85.98ms -step:6600/9000 train_loss:1.8465 train_time:567513ms step_avg:85.99ms -step:6800/9000 train_loss:2.0251 train_time:584832ms step_avg:86.00ms -step:6976/9000 val_loss:1.9311 val_bpb:1.1437 train_time:600040ms step_avg:86.01ms -stopping_early: wallclock_cap train_time:600040ms step:6976/9000 -peak memory allocated: 20300 MiB reserved: 20412 MiB -swa:applying averaged 13 checkpoints -Serialized model: 105783807 bytes -Code size: 66055 bytes -Serialized model int6+zstd: 15826931 bytes -Total submission size int6+zstd: 15892986 bytes -final_int6_roundtrip val_loss:1.9466 val_bpb:1.1529 eval_time:29134ms -final_int6_roundtrip_exact val_loss:1.94662702 val_bpb:1.15290217 -final_int6_sliding_window val_loss:1.9092 val_bpb:1.1307 stride:64 eval_time:85703ms -final_int6_sliding_window_exact val_loss:1.90915845 val_bpb:1.13071416 diff --git a/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train_gpt.py b/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train_gpt.py deleted file mode 100644 index 2ab6785add..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train_gpt.py +++ /dev/null @@ -1,1545 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - # Efficient partial XSA: apply to last N layers only (deep layers have highest self-attention bias) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - # Reshape y into KV head groups — free view, no memory alloc - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready - # Project out self-value component per KV head group - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - # XSA: subtract self-value projection (deep layers only) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - # Enable efficient XSA on the deepest layers (highest self-attention bias) - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if args.swa_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, # must match training model - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/README.md b/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/README.md deleted file mode 100644 index 8a1b15b311..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/README.md +++ /dev/null @@ -1,79 +0,0 @@ -## Record: 11L XSA + EMA + Int6 MLP3x + WD=0.04 (val_bpb: 1.1271) - -**val_bpb = 1.1271** (sliding window, stride=64) | **15.5 MB** artifact | 8xH100 SXM, 600s - -Previous: [PR #70](https://github.com/openai/parameter-golf/pull/70) (9L, 1.1659) → [PR #164](https://github.com/openai/parameter-golf/pull/164) (9L, 1.1524) → [PR #198](https://github.com/openai/parameter-golf/pull/198) (11L, 1.1318) → this - -### Changes from PR #198 - -| | [PR #198](https://github.com/openai/parameter-golf/pull/198) | This | -|---|---|---| -| val_bpb (sliding s64) | 1.1318 | **1.1271** | -| XSA | None | Last 4 layers | -| Weight averaging | SWA (~8 checkpoints) | EMA (decay=0.997, every step) | -| Artifact | 15.7 MB | 15.5 MB | -| Everything else | Same | Same | - -### What's new - -1. **Exclusive Self Attention (XSA)** on last 4 layers. After the standard attention output, XSA subtracts the component aligned with each token's own value vector using an efficient GQA-aware reshape (no repeat_interleave). This encourages attention to capture only information orthogonal to what the token already knows, improving context modeling. Zero new parameters, ~2ms/step overhead. - -2. **EMA replacing SWA**. Instead of collecting periodic SWA checkpoints during warmdown, we maintain an exponential moving average shadow model on GPU that updates every step: `ema = 0.997 * ema + 0.003 * param`. The EMA weights are used for quantization and eval. Smoother averaging than periodic SWA, better generalization and artifact compression. - -### Carried from PR #198 - -- 11 transformer layers with U-Net skip connections -- Orthogonal + muP-scaled init on all large matrices -- 3x MLP (hidden=1536), relu² activation -- Int6 mixed quantization + zstd-22 (int6 on MLP+attention, int8 on embeddings) -- Weight decay 0.04 (Muon + AdamW) -- SmearGate (learned token blending gate, ~512 params) -- Bigram Hash Embedding (2048-bucket, dim=128, projected to 512) -- FlashAttention 3 (direct flash_attn_func calls) -- Sequence length 2048 with NTK-aware RoPE -- Muon optimizer, momentum 0.99 with warmup, warmdown 3000 iters, grad clip 0.3 - -### Configuration - -```bash -NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=2048 XSA_LAST_N=4 \ -EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 \ -MUON_WD=0.04 ADAM_WD=0.04 \ -MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ -MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \ -MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3000 \ -ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -### Key Metrics - -- 7,103 steps in 600s (84ms/step) -- ~5.6B train tokens (7,103 steps x 786,432 tokens/step) -- Peak memory: ~20,400 MiB per GPU - -| Metric | Value | -|--------|-------| -| Pre-quant val_bpb | 1.1427 | -| Int6 roundtrip val_bpb | 1.1494 | -| **Int6 sliding val_bpb (s64)** | **1.1271** | -| Compressed artifact (int6+zstd) | 15,468,512 bytes | -| Code size | 66,133 bytes | -| **Total submission size** | **15,534,645 bytes** | - -### Reproducibility - -| Seed | Steps | Sliding s64 | Artifact | -|------|-------|-------------|----------| -| **1337** | **7,103** | **1.1271** | **15,534,645** | -| 42 | 7,094 | 1.1286 | 15,745,973 | -| 2025 | 7,107 | 1.1284 | 15,649,516 | - -Mean val_bpb: **1.1280**. Submitted: seed 1337 (best). Inter-seed variance: 0.0015. - -### Included files - -- `train_gpt.py` — full training + quantization + evaluation script -- `train.log` — training log from best seed (1337) -- `train_seed1337.log`, `train_seed42.log`, `train_seed2025.log` — all seed logs -- `submission.json` — leaderboard metadata diff --git a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/submission.json b/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/submission.json deleted file mode 100644 index 732dee313e..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/submission.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "author": "Jack Princz", - "github_id": "jfprincz", - "name": "Record: 11L XSA + EMA + Int6 MLP3x + WD=0.04", - "blurb": "11 layers with Exclusive Self Attention (XSA) on last 4 layers, EMA weight averaging (decay=0.997), int6 per-row on all MLP+attention weights, int8 tok_emb, zstd-22. Weight decay 0.04 (Muon+AdamW). OrthoInit + muP scaling. SmearGate + BigramHash(2048x128). FA3. Sliding window eval stride=64, seq=2048.", - "date": "2026-03-21T00:00:00Z", - "val_loss": 1.90301335, - "val_bpb": 1.12707468, - "pre_quant_val_loss": 1.9293, - "pre_quant_val_bpb": 1.1427, - "int6_zstd_val_loss": 1.94077647, - "int6_zstd_val_bpb": 1.14943715, - "bytes_total": 15534645, - "bytes_model_int6_zstd": 15468512, - "bytes_code": 66133 -} diff --git a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train.log b/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train.log deleted file mode 100644 index f2c48a5ae8..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train.log +++ /dev/null @@ -1,1704 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Fri Mar 20 22:49:34 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 27C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 27C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 27C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 25C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 24C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 26C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 25C P0 112W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 25C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 112549 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 112550 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 112551 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 112552 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 112553 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 112554 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 112555 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 112556 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/9000 train_loss:6.9326 train_time:137ms step_avg:136.88ms -step:2/9000 train_loss:8.6010 train_time:215ms step_avg:107.64ms -step:3/9000 train_loss:7.9282 train_time:306ms step_avg:101.93ms -step:4/9000 train_loss:7.2181 train_time:395ms step_avg:98.68ms -step:5/9000 train_loss:6.9899 train_time:484ms step_avg:96.71ms -step:6/9000 train_loss:6.8653 train_time:573ms step_avg:95.48ms -step:7/9000 train_loss:6.8744 train_time:663ms step_avg:94.76ms -step:8/9000 train_loss:6.8852 train_time:754ms step_avg:94.27ms -step:9/9000 train_loss:6.4661 train_time:844ms step_avg:93.76ms -step:10/9000 train_loss:6.1087 train_time:933ms step_avg:93.32ms -step:200/9000 train_loss:2.4028 train_time:16961ms step_avg:84.81ms -step:400/9000 train_loss:2.4090 train_time:33863ms step_avg:84.66ms -step:600/9000 train_loss:2.3253 train_time:50649ms step_avg:84.42ms -step:800/9000 train_loss:2.2314 train_time:67584ms step_avg:84.48ms -step:1000/9000 train_loss:2.2670 train_time:84427ms step_avg:84.43ms -step:1200/9000 train_loss:2.3422 train_time:101372ms step_avg:84.48ms -step:1400/9000 train_loss:2.1767 train_time:118297ms step_avg:84.50ms -step:1600/9000 train_loss:2.0660 train_time:135093ms step_avg:84.43ms -step:1800/9000 train_loss:2.1574 train_time:152012ms step_avg:84.45ms -step:2000/9000 train_loss:2.0607 train_time:168821ms step_avg:84.41ms -step:2200/9000 train_loss:2.1364 train_time:185761ms step_avg:84.44ms -step:2400/9000 train_loss:2.0617 train_time:202583ms step_avg:84.41ms -step:2600/9000 train_loss:2.1061 train_time:219469ms step_avg:84.41ms -step:2800/9000 train_loss:2.1514 train_time:236421ms step_avg:84.44ms -step:3000/9000 train_loss:2.1556 train_time:253214ms step_avg:84.40ms -step:3200/9000 train_loss:2.1708 train_time:270185ms step_avg:84.43ms -step:3400/9000 train_loss:2.0192 train_time:286995ms step_avg:84.41ms -step:3600/9000 train_loss:2.0973 train_time:303900ms step_avg:84.42ms -step:3800/9000 train_loss:2.0745 train_time:320702ms step_avg:84.40ms -step:4000/9000 train_loss:1.9845 train_time:337615ms step_avg:84.40ms -step:4200/9000 train_loss:2.1611 train_time:354577ms step_avg:84.42ms -step:4400/9000 train_loss:2.0459 train_time:371382ms step_avg:84.41ms -step:4600/9000 train_loss:1.8517 train_time:388295ms step_avg:84.41ms -step:4800/9000 train_loss:2.4364 train_time:405091ms step_avg:84.39ms -step:5000/9000 train_loss:2.1143 train_time:422062ms step_avg:84.41ms -step:5200/9000 train_loss:2.0497 train_time:438873ms step_avg:84.40ms -step:5400/9000 train_loss:2.0554 train_time:455815ms step_avg:84.41ms -step:5600/9000 train_loss:1.9594 train_time:472895ms step_avg:84.45ms -step:5800/9000 train_loss:2.0093 train_time:489724ms step_avg:84.44ms -step:6000/9000 train_loss:1.9491 train_time:506734ms step_avg:84.46ms -step:6200/9000 train_loss:1.9584 train_time:523630ms step_avg:84.46ms -step:6400/9000 train_loss:2.0061 train_time:540633ms step_avg:84.47ms -step:6600/9000 train_loss:1.8496 train_time:557451ms step_avg:84.46ms -step:6800/9000 train_loss:2.0281 train_time:574468ms step_avg:84.48ms -step:7000/9000 train_loss:1.7921 train_time:591382ms step_avg:84.48ms -step:7103/9000 val_loss:1.9293 val_bpb:1.1427 train_time:600021ms step_avg:84.47ms -stopping_early: wallclock_cap train_time:600021ms step:7103/9000 -peak memory allocated: 20600 MiB reserved: 20658 MiB -ema:applying EMA weights -Serialized model: 105783807 bytes -Code size: 66133 bytes -Serialized model int6+zstd: 15468512 bytes -Total submission size int6+zstd: 15534645 bytes -final_int6_roundtrip val_loss:1.9408 val_bpb:1.1494 eval_time:5637ms -final_int6_roundtrip_exact val_loss:1.94077647 val_bpb:1.14943715 -final_int6_sliding_window val_loss:1.9030 val_bpb:1.1271 stride:64 eval_time:69809ms -final_int6_sliding_window_exact val_loss:1.90301335 val_bpb:1.12707468 diff --git a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_gpt.py b/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_gpt.py deleted file mode 100644 index 7b83fa4d3e..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_gpt.py +++ /dev/null @@ -1,1555 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed1337.log b/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed1337.log deleted file mode 100644 index f2c48a5ae8..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed1337.log +++ /dev/null @@ -1,1704 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Fri Mar 20 22:49:34 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 27C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 27C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 27C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 25C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 24C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 26C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 25C P0 112W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 25C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 112549 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 112550 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 112551 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 112552 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 112553 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 112554 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 112555 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 112556 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/9000 train_loss:6.9326 train_time:137ms step_avg:136.88ms -step:2/9000 train_loss:8.6010 train_time:215ms step_avg:107.64ms -step:3/9000 train_loss:7.9282 train_time:306ms step_avg:101.93ms -step:4/9000 train_loss:7.2181 train_time:395ms step_avg:98.68ms -step:5/9000 train_loss:6.9899 train_time:484ms step_avg:96.71ms -step:6/9000 train_loss:6.8653 train_time:573ms step_avg:95.48ms -step:7/9000 train_loss:6.8744 train_time:663ms step_avg:94.76ms -step:8/9000 train_loss:6.8852 train_time:754ms step_avg:94.27ms -step:9/9000 train_loss:6.4661 train_time:844ms step_avg:93.76ms -step:10/9000 train_loss:6.1087 train_time:933ms step_avg:93.32ms -step:200/9000 train_loss:2.4028 train_time:16961ms step_avg:84.81ms -step:400/9000 train_loss:2.4090 train_time:33863ms step_avg:84.66ms -step:600/9000 train_loss:2.3253 train_time:50649ms step_avg:84.42ms -step:800/9000 train_loss:2.2314 train_time:67584ms step_avg:84.48ms -step:1000/9000 train_loss:2.2670 train_time:84427ms step_avg:84.43ms -step:1200/9000 train_loss:2.3422 train_time:101372ms step_avg:84.48ms -step:1400/9000 train_loss:2.1767 train_time:118297ms step_avg:84.50ms -step:1600/9000 train_loss:2.0660 train_time:135093ms step_avg:84.43ms -step:1800/9000 train_loss:2.1574 train_time:152012ms step_avg:84.45ms -step:2000/9000 train_loss:2.0607 train_time:168821ms step_avg:84.41ms -step:2200/9000 train_loss:2.1364 train_time:185761ms step_avg:84.44ms -step:2400/9000 train_loss:2.0617 train_time:202583ms step_avg:84.41ms -step:2600/9000 train_loss:2.1061 train_time:219469ms step_avg:84.41ms -step:2800/9000 train_loss:2.1514 train_time:236421ms step_avg:84.44ms -step:3000/9000 train_loss:2.1556 train_time:253214ms step_avg:84.40ms -step:3200/9000 train_loss:2.1708 train_time:270185ms step_avg:84.43ms -step:3400/9000 train_loss:2.0192 train_time:286995ms step_avg:84.41ms -step:3600/9000 train_loss:2.0973 train_time:303900ms step_avg:84.42ms -step:3800/9000 train_loss:2.0745 train_time:320702ms step_avg:84.40ms -step:4000/9000 train_loss:1.9845 train_time:337615ms step_avg:84.40ms -step:4200/9000 train_loss:2.1611 train_time:354577ms step_avg:84.42ms -step:4400/9000 train_loss:2.0459 train_time:371382ms step_avg:84.41ms -step:4600/9000 train_loss:1.8517 train_time:388295ms step_avg:84.41ms -step:4800/9000 train_loss:2.4364 train_time:405091ms step_avg:84.39ms -step:5000/9000 train_loss:2.1143 train_time:422062ms step_avg:84.41ms -step:5200/9000 train_loss:2.0497 train_time:438873ms step_avg:84.40ms -step:5400/9000 train_loss:2.0554 train_time:455815ms step_avg:84.41ms -step:5600/9000 train_loss:1.9594 train_time:472895ms step_avg:84.45ms -step:5800/9000 train_loss:2.0093 train_time:489724ms step_avg:84.44ms -step:6000/9000 train_loss:1.9491 train_time:506734ms step_avg:84.46ms -step:6200/9000 train_loss:1.9584 train_time:523630ms step_avg:84.46ms -step:6400/9000 train_loss:2.0061 train_time:540633ms step_avg:84.47ms -step:6600/9000 train_loss:1.8496 train_time:557451ms step_avg:84.46ms -step:6800/9000 train_loss:2.0281 train_time:574468ms step_avg:84.48ms -step:7000/9000 train_loss:1.7921 train_time:591382ms step_avg:84.48ms -step:7103/9000 val_loss:1.9293 val_bpb:1.1427 train_time:600021ms step_avg:84.47ms -stopping_early: wallclock_cap train_time:600021ms step:7103/9000 -peak memory allocated: 20600 MiB reserved: 20658 MiB -ema:applying EMA weights -Serialized model: 105783807 bytes -Code size: 66133 bytes -Serialized model int6+zstd: 15468512 bytes -Total submission size int6+zstd: 15534645 bytes -final_int6_roundtrip val_loss:1.9408 val_bpb:1.1494 eval_time:5637ms -final_int6_roundtrip_exact val_loss:1.94077647 val_bpb:1.14943715 -final_int6_sliding_window val_loss:1.9030 val_bpb:1.1271 stride:64 eval_time:69809ms -final_int6_sliding_window_exact val_loss:1.90301335 val_bpb:1.12707468 diff --git a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed2025.log b/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed2025.log deleted file mode 100644 index 48f738fd35..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed2025.log +++ /dev/null @@ -1,1704 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Fri Mar 20 23:02:06 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 33C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 37C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 37C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 30C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 30C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 37C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 113489 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 113490 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 113491 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 113492 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 113493 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 113494 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 113495 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 113496 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:2025 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/9000 train_loss:6.9287 train_time:143ms step_avg:143.48ms -step:2/9000 train_loss:8.4960 train_time:449ms step_avg:224.52ms -step:3/9000 train_loss:8.0184 train_time:556ms step_avg:185.21ms -step:4/9000 train_loss:7.1375 train_time:658ms step_avg:164.53ms -step:5/9000 train_loss:7.0775 train_time:757ms step_avg:151.45ms -step:6/9000 train_loss:6.9272 train_time:861ms step_avg:143.42ms -step:7/9000 train_loss:6.8288 train_time:963ms step_avg:137.58ms -step:8/9000 train_loss:6.7827 train_time:1056ms step_avg:131.94ms -step:9/9000 train_loss:6.5175 train_time:1147ms step_avg:127.42ms -step:10/9000 train_loss:6.3236 train_time:1243ms step_avg:124.27ms -step:200/9000 train_loss:2.4123 train_time:17233ms step_avg:86.17ms -step:400/9000 train_loss:2.4080 train_time:34153ms step_avg:85.38ms -step:600/9000 train_loss:2.3298 train_time:50959ms step_avg:84.93ms -step:800/9000 train_loss:2.2295 train_time:67925ms step_avg:84.91ms -step:1000/9000 train_loss:2.2680 train_time:84722ms step_avg:84.72ms -step:1200/9000 train_loss:2.3497 train_time:101720ms step_avg:84.77ms -step:1400/9000 train_loss:2.1779 train_time:118650ms step_avg:84.75ms -step:1600/9000 train_loss:2.0706 train_time:135486ms step_avg:84.68ms -step:1800/9000 train_loss:2.1519 train_time:152480ms step_avg:84.71ms -step:2000/9000 train_loss:2.0605 train_time:169320ms step_avg:84.66ms -step:2200/9000 train_loss:2.1334 train_time:186262ms step_avg:84.66ms -step:2400/9000 train_loss:2.0598 train_time:203098ms step_avg:84.62ms -step:2600/9000 train_loss:2.1041 train_time:220022ms step_avg:84.62ms -step:2800/9000 train_loss:2.1464 train_time:236943ms step_avg:84.62ms -step:3000/9000 train_loss:2.1567 train_time:253817ms step_avg:84.61ms -step:3200/9000 train_loss:2.1695 train_time:270754ms step_avg:84.61ms -step:3400/9000 train_loss:2.0205 train_time:287594ms step_avg:84.59ms -step:3600/9000 train_loss:2.0997 train_time:304462ms step_avg:84.57ms -step:3800/9000 train_loss:2.0767 train_time:321297ms step_avg:84.55ms -step:4000/9000 train_loss:1.9866 train_time:338222ms step_avg:84.56ms -step:4200/9000 train_loss:2.1642 train_time:355089ms step_avg:84.55ms -step:4400/9000 train_loss:2.0441 train_time:371876ms step_avg:84.52ms -step:4600/9000 train_loss:1.8538 train_time:388768ms step_avg:84.51ms -step:4800/9000 train_loss:2.4375 train_time:405564ms step_avg:84.49ms -step:5000/9000 train_loss:2.1141 train_time:422475ms step_avg:84.50ms -step:5200/9000 train_loss:2.0508 train_time:439276ms step_avg:84.48ms -step:5400/9000 train_loss:2.0568 train_time:456231ms step_avg:84.49ms -step:5600/9000 train_loss:1.9642 train_time:473111ms step_avg:84.48ms -step:5800/9000 train_loss:2.0081 train_time:489930ms step_avg:84.47ms -step:6000/9000 train_loss:1.9514 train_time:506830ms step_avg:84.47ms -step:6200/9000 train_loss:1.9608 train_time:523627ms step_avg:84.46ms -step:6400/9000 train_loss:2.0102 train_time:540513ms step_avg:84.46ms -step:6600/9000 train_loss:1.8526 train_time:557316ms step_avg:84.44ms -step:6800/9000 train_loss:2.0307 train_time:574198ms step_avg:84.44ms -step:7000/9000 train_loss:1.7989 train_time:591074ms step_avg:84.44ms -step:7107/9000 val_loss:1.9310 val_bpb:1.1436 train_time:600043ms step_avg:84.43ms -stopping_early: wallclock_cap train_time:600043ms step:7107/9000 -peak memory allocated: 20600 MiB reserved: 20658 MiB -ema:applying EMA weights -Serialized model: 105783807 bytes -Code size: 66133 bytes -Serialized model int6+zstd: 15583383 bytes -Total submission size int6+zstd: 15649516 bytes -final_int6_roundtrip val_loss:1.9427 val_bpb:1.1506 eval_time:5634ms -final_int6_roundtrip_exact val_loss:1.94270012 val_bpb:1.15057644 -final_int6_sliding_window val_loss:1.9052 val_bpb:1.1284 stride:64 eval_time:70072ms -final_int6_sliding_window_exact val_loss:1.90518030 val_bpb:1.12835808 diff --git a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed42.log b/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed42.log deleted file mode 100644 index 91220e5613..0000000000 --- a/records/track_10min_16mb/2026-03-20_11L_XSA4_EMA_Int6_MLP3x_WD04_1.1271/train_seed42.log +++ /dev/null @@ -1,1704 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (self.dim / (self.dim - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Fri Mar 20 22:27:07 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 33C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 37C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 37C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 30C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 30C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 37C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 34C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 106002 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 106003 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 106004 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 106005 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 106006 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 106007 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 106008 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 106009 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/9000 train_loss:6.9320 train_time:147ms step_avg:147.02ms -step:2/9000 train_loss:8.6575 train_time:249ms step_avg:124.57ms -step:3/9000 train_loss:7.9030 train_time:376ms step_avg:125.30ms -step:4/9000 train_loss:7.1364 train_time:479ms step_avg:119.79ms -step:5/9000 train_loss:6.9033 train_time:582ms step_avg:116.46ms -step:6/9000 train_loss:6.8939 train_time:681ms step_avg:113.47ms -step:7/9000 train_loss:6.7697 train_time:783ms step_avg:111.86ms -step:8/9000 train_loss:6.7490 train_time:895ms step_avg:111.85ms -step:9/9000 train_loss:6.4122 train_time:1003ms step_avg:111.46ms -step:10/9000 train_loss:6.1059 train_time:1108ms step_avg:110.79ms -step:200/9000 train_loss:2.3893 train_time:17190ms step_avg:85.95ms -step:400/9000 train_loss:2.4075 train_time:34248ms step_avg:85.62ms -step:600/9000 train_loss:2.3202 train_time:51202ms step_avg:85.34ms -step:800/9000 train_loss:2.2272 train_time:68226ms step_avg:85.28ms -step:1000/9000 train_loss:2.2662 train_time:85119ms step_avg:85.12ms -step:1200/9000 train_loss:2.3444 train_time:102169ms step_avg:85.14ms -step:1400/9000 train_loss:2.1804 train_time:119191ms step_avg:85.14ms -step:1600/9000 train_loss:2.0751 train_time:135994ms step_avg:85.00ms -step:1800/9000 train_loss:2.1527 train_time:153021ms step_avg:85.01ms -step:2000/9000 train_loss:2.0628 train_time:169826ms step_avg:84.91ms -step:2200/9000 train_loss:2.1402 train_time:186799ms step_avg:84.91ms -step:2400/9000 train_loss:2.0597 train_time:203600ms step_avg:84.83ms -step:2600/9000 train_loss:2.1052 train_time:220497ms step_avg:84.81ms -step:2800/9000 train_loss:2.1554 train_time:237407ms step_avg:84.79ms -step:3000/9000 train_loss:2.1583 train_time:254219ms step_avg:84.74ms -step:3200/9000 train_loss:2.1703 train_time:271155ms step_avg:84.74ms -step:3400/9000 train_loss:2.0206 train_time:287991ms step_avg:84.70ms -step:3600/9000 train_loss:2.0998 train_time:304878ms step_avg:84.69ms -step:3800/9000 train_loss:2.0797 train_time:321716ms step_avg:84.66ms -step:4000/9000 train_loss:1.9845 train_time:338684ms step_avg:84.67ms -step:4200/9000 train_loss:2.1687 train_time:355556ms step_avg:84.66ms -step:4400/9000 train_loss:2.0449 train_time:372344ms step_avg:84.62ms -step:4600/9000 train_loss:1.8520 train_time:389329ms step_avg:84.64ms -step:4800/9000 train_loss:2.4407 train_time:406140ms step_avg:84.61ms -step:5000/9000 train_loss:2.1133 train_time:423022ms step_avg:84.60ms -step:5200/9000 train_loss:2.0523 train_time:439818ms step_avg:84.58ms -step:5400/9000 train_loss:2.0586 train_time:456791ms step_avg:84.59ms -step:5600/9000 train_loss:1.9656 train_time:473716ms step_avg:84.59ms -step:5800/9000 train_loss:2.0096 train_time:490533ms step_avg:84.57ms -step:6000/9000 train_loss:1.9523 train_time:507435ms step_avg:84.57ms -step:6200/9000 train_loss:1.9603 train_time:524255ms step_avg:84.56ms -step:6400/9000 train_loss:2.0087 train_time:541288ms step_avg:84.58ms -step:6600/9000 train_loss:1.8530 train_time:558116ms step_avg:84.56ms -step:6800/9000 train_loss:2.0303 train_time:575121ms step_avg:84.58ms -step:7000/9000 train_loss:1.7987 train_time:592095ms step_avg:84.59ms -step:7094/9000 val_loss:1.9314 val_bpb:1.1439 train_time:600017ms step_avg:84.58ms -stopping_early: wallclock_cap train_time:600017ms step:7094/9000 -peak memory allocated: 20600 MiB reserved: 20660 MiB -ema:applying EMA weights -Serialized model: 105783807 bytes -Code size: 66133 bytes -Serialized model int6+zstd: 15679840 bytes -Total submission size int6+zstd: 15745973 bytes -final_int6_roundtrip val_loss:1.9429 val_bpb:1.1507 eval_time:30815ms -final_int6_roundtrip_exact val_loss:1.94285209 val_bpb:1.15066644 -final_int6_sliding_window val_loss:1.9056 val_bpb:1.1286 stride:64 eval_time:84494ms -final_int6_sliding_window_exact val_loss:1.90563076 val_bpb:1.12862486 diff --git a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/README.md b/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/README.md deleted file mode 100644 index 1f71780175..0000000000 --- a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/README.md +++ /dev/null @@ -1,77 +0,0 @@ -# Int6 MLP3x + SmearGate + BigramHash + OrthoInit + Muon WD + SWA - -## Score: mean val_bpb = 1.1458 (3 seeds: 1.1460, 1.1466, 1.1449) - -Trained on 8×H100 SXM in 600 seconds. 15.86MB artifact (int6+zstd-22). - -## Approach - -Seven techniques stacked on the baseline 9-layer, 512-dim GPT: - -### 1. Per-Row Int6 Quantization + zstd-22 Compression -MLP and attention weight matrices quantized to int6 ([-32, 31]) with per-row scaling. Tied embeddings remain in fp16 (quantization-sensitive). The last transformer layer's key projection is kept in fp16 to reduce the quantization penalty on late-layer attention. zstd at level 22 provides ~5% better compression than zlib-9 on int6 data. - -### 2. 3× MLP Expansion -MLP hidden dimension increased from 1024 (2×) to 1536 (3×), enabled by the byte savings from int6 quantization. This is the single largest contributor to the improvement. - -### 3. SmearGate -A learned gate blending each token's embedding with the previous token's embedding, providing lightweight bigram-level context at the embedding layer. Adds ~512 parameters. - -### 4. BigramHash Embedding -A 4096-bucket hash table (dim=128, projected to 512) mapping adjacent token pairs to learned embeddings via `(prev_token * 31 + curr_token) % 4096`. Adds ~524K parameters. Complements SmearGate with an additive bigram signal. - -### 5. Orthogonal Weight Initialization -All large weight matrices initialized with `orthogonal_(gain=1.0)`. Output projections scaled by `1/sqrt(2 * num_layers)` following muP conventions. Accelerates early convergence. - -### 6. Muon Optimizer with Weight Decay -Muon with decoupled weight decay WD=0.04 (swept from 0.01–0.05, optimal at 0.04). Momentum warmup from 0.92 to 0.99 over 1500 steps. AdamW WD=0.01 for embedding and scalar parameters. Weight decay regularizes magnitudes, directly improving int6 quantization quality. - -### 7. Stochastic Weight Averaging (SWA) -SWA every 50 steps over the last 50% of training (~30 checkpoints averaged). Produces smoother weight distributions that quantize better. Swept swa_every from 200 down to 25; optimal at 50. - -## Hyperparameters - -| Parameter | Value | -|-----------|-------| -| num_layers | 9 | -| model_dim | 512 | -| mlp_mult | 3.0 (hidden=1536) | -| train_seq_len | 2048 | -| train_batch_tokens | 786,432 | -| warmdown_iters | 3000 | -| matrix_lr | 0.02 | -| scalar_lr | 0.02 | -| tied_embed_lr | 0.03 | -| muon_momentum | 0.99 (warmup from 0.92 over 1500 steps) | -| muon_weight_decay | 0.04 | -| adamw_weight_decay | 0.01 | -| grad_clip_norm | 0.3 | -| eval_stride | 64 | -| swa_every | 50 | -| swa_start_frac | 0.5 | -| bigram_vocab_size | 4096 | -| bigram_dim | 128 | -| compressor | zstd (level 22) | - -## Key Metrics - -- **Mean val_bpb: 1.1458** (std: 0.0008) -- Pre-quant val_bpb: 1.1616 -- Quantization penalty: 0.016 bpb (int6 vs fp16) -- Training: 7,379 steps in 600s (81.3 ms/step) -- Model params: ~22M -- Artifact size: 15.86MB (int6+zstd-22) - -## Reproducibility - -Three independent training runs with different random seeds: - -| Seed | val_loss | val_bpb | -|------|----------|---------| -| 1337 | 1.93492 | 1.14597 | -| 42 | 1.93591 | 1.14656 | -| 7 | 1.93314 | 1.14492 | -| **Mean** | **1.93466** | **1.14582** | -| **Std** | **0.00139** | **0.00082** | - -Improvement over current SOTA (1.1748): **-0.0290 bpb / -0.0503 nats** (p < 0.001). diff --git a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/submission.json b/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/submission.json deleted file mode 100644 index fc7c902599..0000000000 --- a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/submission.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "author": "Raahil Shah", - "github_id": "raahilshah", - "name": "Int6 MLP3x + SmearGate + BigramHash + OrthoInit + Muon WD + SWA", - "blurb": "Per-row int6 quantization on MLP/attention weights with zstd-22 compression, enabling 3x MLP expansion (hidden=1536). SmearGate blends adjacent token embeddings via a learned gate. BigramHash embedding (4096 buckets, dim=128) captures token-pair context. Orthogonal weight initialization with muP output scaling. Muon optimizer with decoupled weight decay (WD=0.04) and momentum warmup (0.92->0.99 over 1500 steps). Stochastic Weight Averaging every 50 steps over the last 50% of training. Trained at seq_len=2048 with batch=786432, grad_clip=0.3, warmdown=3000. Sliding window evaluation at stride=64.", - "date": "2026-03-20T05:30:00Z", - "val_loss": 1.93465876, - "val_bpb": 1.14581692, - "val_loss_std": 0.00139, - "val_bpb_std": 0.00082, - "seeds": [1337, 42, 7], - "seed_results": { - "1337": {"val_loss": 1.93492097, "val_bpb": 1.14597222}, - "42": {"val_loss": 1.93591485, "val_bpb": 1.14656085}, - "7": {"val_loss": 1.93314046, "val_bpb": 1.14491770} - }, - "pre_quant_val_loss": 1.9613, - "pre_quant_val_bpb": 1.1616, - "step_stop": 7379, - "wallclock_seconds": 600.075, - "eval_time_seconds": 155.204, - "bytes_total": 15862650, - "bytes_model_int6_zstd": 15810407, - "bytes_code": 52243 -} diff --git a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_gpt.py b/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_gpt.py deleted file mode 100644 index 49438aeea8..0000000000 --- a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_gpt.py +++ /dev/null @@ -1,1218 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.01)) - - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.data.mul_(1.0 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION -# ----------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# ----------------------------- -# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) -# ----------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", - ).split(",") - if pattern -) -FP16_KEEP_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1e-12).to(torch.float16) - scale = scale.clamp_min(torch.finfo(torch.float16).tiny) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(max(amax / 31.0, 1e-12), dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): - result[name] = t.to(dtype=torch.float16).contiguous() - meta[name] = "passthrough_fp16" - continue - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta[name] - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: float): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class SmearGate(nn.Module): - """Blend each token's embedding with the previous token's embedding.""" - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - """Hash consecutive token pairs into a learned embedding table.""" - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: float, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.smear = SmearGate(model_dim) - self.blocks = nn.ModuleList( - [ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - def forward_logits(self, input_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, -) -> tuple[float, float]: - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if rank == 0 and (bi // batch_seqs) % 50 == 0: - done = min(bi + batch_seqs, len(my_windows)) - pct = done / len(my_windows) * 100 - running_bpb = 0.0 - if token_count.item() > 0: - rl = (loss_sum / token_count).item() - running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) - print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # MODEL + OPTIMIZER SETUP - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=0.04, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.weight_decay, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # DATA LOADER & MODEL WARMUP - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # MAIN TRAINING LOOP - training_time_ms = 0.0 - stop_after_step: int | None = None - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # SWA: collect checkpoints during warmdown - if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Apply SWA if collected - if args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - current_state = base_model.state_dict() - avg_state = { - name: (tensor / swa_count).to(dtype=current_state[name].dtype) - for name, tensor in swa_state.items() - } - base_model.load_state_dict(avg_state, strict=True) - - # SERIALIZATION + ROUNDTRIP VALIDATION - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - # INT6 mixed quantization + zstd/zlib export - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - if _COMPRESSOR == "zstd": - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) - else: - quant_blob = zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - if _COMPRESSOR == "zstd": - decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) - else: - decompressed = zlib.decompress(quant_blob_disk) - quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - base_model.load_state_dict(deq_state, strict=True) - - # Sliding window eval on int6-roundtripped weights - torch.cuda.synchronize() - t_qeval = time.perf_counter() - if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: - log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") - q_val_loss, q_val_bpb = eval_val_sliding( - args, base_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, - ) - else: - log0("final_eval_mode:standard") - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed1337.log b/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed1337.log deleted file mode 100644 index 5cc76640d4..0000000000 --- a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed1337.log +++ /dev/null @@ -1,217 +0,0 @@ -logs/repro_seed1337.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:22368841 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9341 val_bpb:4.1068 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9348 train_time:141ms step_avg:140.86ms -step:2/20000 train_loss:8.1566 train_time:191ms step_avg:95.73ms -step:3/20000 train_loss:7.7064 train_time:270ms step_avg:90.08ms -step:4/20000 train_loss:6.9837 train_time:349ms step_avg:87.32ms -step:5/20000 train_loss:6.7600 train_time:429ms step_avg:85.70ms -step:6/20000 train_loss:6.5915 train_time:508ms step_avg:84.74ms -step:7/20000 train_loss:6.4894 train_time:588ms step_avg:83.96ms -step:8/20000 train_loss:6.4654 train_time:668ms step_avg:83.45ms -step:9/20000 train_loss:6.3310 train_time:747ms step_avg:83.00ms -step:10/20000 train_loss:6.0755 train_time:827ms step_avg:82.69ms -step:100/20000 train_loss:3.2043 train_time:8017ms step_avg:80.17ms -step:200/20000 train_loss:2.4054 train_time:16091ms step_avg:80.45ms -step:300/20000 train_loss:2.5684 train_time:24169ms step_avg:80.56ms -step:400/20000 train_loss:2.4415 train_time:32265ms step_avg:80.66ms -step:500/20000 train_loss:2.4149 train_time:40280ms step_avg:80.56ms -step:500/20000 val_loss:2.3768 val_bpb:1.4077 train_time:40311ms step_avg:80.62ms -step:600/20000 train_loss:2.3469 train_time:48365ms step_avg:80.61ms -step:700/20000 train_loss:2.3682 train_time:56464ms step_avg:80.66ms -step:800/20000 train_loss:2.2582 train_time:64564ms step_avg:80.70ms -step:900/20000 train_loss:2.1457 train_time:72674ms step_avg:80.75ms -step:1000/20000 train_loss:2.2893 train_time:80724ms step_avg:80.72ms -step:1000/20000 val_loss:2.2443 val_bpb:1.3292 train_time:80754ms step_avg:80.75ms -step:1100/20000 train_loss:2.3382 train_time:88850ms step_avg:80.77ms -step:1200/20000 train_loss:2.3689 train_time:96958ms step_avg:80.80ms -step:1300/20000 train_loss:2.1175 train_time:105083ms step_avg:80.83ms -step:1400/20000 train_loss:2.1998 train_time:113218ms step_avg:80.87ms -step:1500/20000 train_loss:2.2363 train_time:121286ms step_avg:80.86ms -step:1500/20000 val_loss:2.2003 val_bpb:1.3031 train_time:121316ms step_avg:80.88ms -step:1600/20000 train_loss:2.0913 train_time:129420ms step_avg:80.89ms -step:1700/20000 train_loss:2.1554 train_time:137554ms step_avg:80.91ms -step:1800/20000 train_loss:2.1699 train_time:145705ms step_avg:80.95ms -step:1900/20000 train_loss:2.1431 train_time:153783ms step_avg:80.94ms -step:2000/20000 train_loss:2.0824 train_time:161922ms step_avg:80.96ms -step:2000/20000 val_loss:2.1458 val_bpb:1.2709 train_time:161951ms step_avg:80.98ms -step:2100/20000 train_loss:2.0603 train_time:170068ms step_avg:80.98ms -step:2200/20000 train_loss:2.1501 train_time:178204ms step_avg:81.00ms -step:2300/20000 train_loss:2.1217 train_time:186356ms step_avg:81.02ms -step:2400/20000 train_loss:2.0800 train_time:194428ms step_avg:81.01ms -step:2500/20000 train_loss:2.1844 train_time:202574ms step_avg:81.03ms -step:2500/20000 val_loss:2.1186 val_bpb:1.2548 train_time:202604ms step_avg:81.04ms -step:2600/20000 train_loss:2.1216 train_time:210724ms step_avg:81.05ms -step:2700/20000 train_loss:2.1134 train_time:218878ms step_avg:81.07ms -step:2800/20000 train_loss:2.1697 train_time:227036ms step_avg:81.08ms -step:2900/20000 train_loss:2.0371 train_time:235109ms step_avg:81.07ms -step:3000/20000 train_loss:2.1729 train_time:243248ms step_avg:81.08ms -step:3000/20000 val_loss:2.1036 val_bpb:1.2458 train_time:243278ms step_avg:81.09ms -step:3100/20000 train_loss:2.0532 train_time:251404ms step_avg:81.10ms -step:3200/20000 train_loss:2.1856 train_time:259526ms step_avg:81.10ms -step:3300/20000 train_loss:2.0843 train_time:267598ms step_avg:81.09ms -step:3400/20000 train_loss:2.0318 train_time:275748ms step_avg:81.10ms -step:3500/20000 train_loss:2.1929 train_time:283889ms step_avg:81.11ms -step:3500/20000 val_loss:2.0947 val_bpb:1.2406 train_time:283918ms step_avg:81.12ms -step:3600/20000 train_loss:2.1124 train_time:292032ms step_avg:81.12ms -step:3700/20000 train_loss:2.1085 train_time:300166ms step_avg:81.13ms -step:3800/20000 train_loss:2.0913 train_time:308237ms step_avg:81.11ms -step:3900/20000 train_loss:2.0922 train_time:316389ms step_avg:81.13ms -step:4000/20000 train_loss:1.9969 train_time:324531ms step_avg:81.13ms -step:4000/20000 val_loss:2.0882 val_bpb:1.2368 train_time:324562ms step_avg:81.14ms -step:4100/20000 train_loss:2.0421 train_time:332679ms step_avg:81.14ms -step:4200/20000 train_loss:2.1781 train_time:340832ms step_avg:81.15ms -step:4300/20000 train_loss:2.0885 train_time:348910ms step_avg:81.14ms -step:4400/20000 train_loss:2.0679 train_time:357041ms step_avg:81.15ms -step:4500/20000 train_loss:2.1553 train_time:365175ms step_avg:81.15ms -step:4500/20000 val_loss:2.0779 val_bpb:1.2307 train_time:365205ms step_avg:81.16ms -step:4600/20000 train_loss:1.8732 train_time:373321ms step_avg:81.16ms -step:4700/20000 train_loss:2.2650 train_time:381397ms step_avg:81.15ms -step:4800/20000 train_loss:2.4581 train_time:389535ms step_avg:81.15ms -step:4900/20000 train_loss:2.0833 train_time:397690ms step_avg:81.16ms -step:5000/20000 train_loss:2.1380 train_time:405816ms step_avg:81.16ms -step:5000/20000 val_loss:2.0581 val_bpb:1.2190 train_time:405846ms step_avg:81.17ms -step:5100/20000 train_loss:2.1602 train_time:413970ms step_avg:81.17ms -step:5200/20000 train_loss:2.0790 train_time:422033ms step_avg:81.16ms -step:5300/20000 train_loss:2.0418 train_time:430163ms step_avg:81.16ms -step:5400/20000 train_loss:2.0823 train_time:438300ms step_avg:81.17ms -step:5500/20000 train_loss:2.0531 train_time:446439ms step_avg:81.17ms -step:5500/20000 val_loss:2.0397 val_bpb:1.2080 train_time:446470ms step_avg:81.18ms -step:5600/20000 train_loss:1.9954 train_time:454577ms step_avg:81.17ms -step:5700/20000 train_loss:2.0530 train_time:462635ms step_avg:81.16ms -step:5800/20000 train_loss:2.0411 train_time:470766ms step_avg:81.17ms -swa:start step:5900 -step:5900/20000 train_loss:1.9503 train_time:478911ms step_avg:81.17ms -step:6000/20000 train_loss:1.9795 train_time:487162ms step_avg:81.19ms -step:6000/20000 val_loss:2.0221 val_bpb:1.1976 train_time:487222ms step_avg:81.20ms -step:6100/20000 train_loss:1.9570 train_time:495309ms step_avg:81.20ms -step:6200/20000 train_loss:1.9956 train_time:503492ms step_avg:81.21ms -step:6300/20000 train_loss:1.9887 train_time:511672ms step_avg:81.22ms -step:6400/20000 train_loss:2.0452 train_time:519869ms step_avg:81.23ms -step:6500/20000 train_loss:2.1266 train_time:528043ms step_avg:81.24ms -step:6500/20000 val_loss:1.9989 val_bpb:1.1839 train_time:528099ms step_avg:81.25ms -step:6600/20000 train_loss:1.8915 train_time:536172ms step_avg:81.24ms -step:6700/20000 train_loss:1.9858 train_time:544370ms step_avg:81.25ms -step:6800/20000 train_loss:2.0731 train_time:552561ms step_avg:81.26ms -step:6900/20000 train_loss:1.8717 train_time:560768ms step_avg:81.27ms -step:7000/20000 train_loss:1.8351 train_time:568971ms step_avg:81.28ms -step:7000/20000 val_loss:1.9765 val_bpb:1.1706 train_time:569044ms step_avg:81.29ms -step:7100/20000 train_loss:1.9766 train_time:577112ms step_avg:81.28ms -step:7200/20000 train_loss:1.9235 train_time:585354ms step_avg:81.30ms -step:7300/20000 train_loss:2.0470 train_time:593523ms step_avg:81.30ms -step:7379/20000 val_loss:1.9613 val_bpb:1.1616 train_time:600075ms step_avg:81.32ms -stopping_early: wallclock_cap train_time:600075ms step:7379/20000 -peak memory allocated: 16965 MiB reserved: 17074 MiB -swa:applying averaged 30 checkpoints -Serialized model: 87413467 bytes -Code size: 52243 bytes -Total submission size: 87465710 bytes -Serialized model int6+zstd: 15810407 bytes -Total submission size int8+zlib: 15862650 bytes -final_eval_mode:sliding_window stride:64 batch_seqs:32 - sliding_eval [ 0.0%] 32/121136 windows running_bpb=1.208722 - sliding_eval [ 1.3%] 1632/121136 windows running_bpb=1.140131 - sliding_eval [ 2.7%] 3232/121136 windows running_bpb=1.142313 - sliding_eval [ 4.0%] 4832/121136 windows running_bpb=1.135811 - sliding_eval [ 5.3%] 6432/121136 windows running_bpb=1.148080 - sliding_eval [ 6.6%] 8032/121136 windows running_bpb=1.149270 - sliding_eval [ 8.0%] 9632/121136 windows running_bpb=1.150707 - sliding_eval [ 9.3%] 11232/121136 windows running_bpb=1.145960 - sliding_eval [ 10.6%] 12832/121136 windows running_bpb=1.143566 - sliding_eval [ 11.9%] 14432/121136 windows running_bpb=1.145048 - sliding_eval [ 13.2%] 16032/121136 windows running_bpb=1.153703 - sliding_eval [ 14.6%] 17632/121136 windows running_bpb=1.152061 - sliding_eval [ 15.9%] 19232/121136 windows running_bpb=1.153431 - sliding_eval [ 17.2%] 20832/121136 windows running_bpb=1.151819 - sliding_eval [ 18.5%] 22432/121136 windows running_bpb=1.150317 - sliding_eval [ 19.8%] 24032/121136 windows running_bpb=1.150648 - sliding_eval [ 21.2%] 25632/121136 windows running_bpb=1.152022 - sliding_eval [ 22.5%] 27232/121136 windows running_bpb=1.152565 - sliding_eval [ 23.8%] 28832/121136 windows running_bpb=1.158651 - sliding_eval [ 25.1%] 30432/121136 windows running_bpb=1.156082 - sliding_eval [ 26.4%] 32032/121136 windows running_bpb=1.156986 - sliding_eval [ 27.8%] 33632/121136 windows running_bpb=1.155627 - sliding_eval [ 29.1%] 35232/121136 windows running_bpb=1.155001 - sliding_eval [ 30.4%] 36832/121136 windows running_bpb=1.154637 - sliding_eval [ 31.7%] 38432/121136 windows running_bpb=1.155395 - sliding_eval [ 33.0%] 40032/121136 windows running_bpb=1.153042 - sliding_eval [ 34.4%] 41632/121136 windows running_bpb=1.152081 - sliding_eval [ 35.7%] 43232/121136 windows running_bpb=1.152401 - sliding_eval [ 37.0%] 44832/121136 windows running_bpb=1.151179 - sliding_eval [ 38.3%] 46432/121136 windows running_bpb=1.151070 - sliding_eval [ 39.7%] 48032/121136 windows running_bpb=1.150304 - sliding_eval [ 41.0%] 49632/121136 windows running_bpb=1.151501 - sliding_eval [ 42.3%] 51232/121136 windows running_bpb=1.152604 - sliding_eval [ 43.6%] 52832/121136 windows running_bpb=1.153122 - sliding_eval [ 44.9%] 54432/121136 windows running_bpb=1.152596 - sliding_eval [ 46.3%] 56032/121136 windows running_bpb=1.152950 - sliding_eval [ 47.6%] 57632/121136 windows running_bpb=1.152072 - sliding_eval [ 48.9%] 59232/121136 windows running_bpb=1.148163 - sliding_eval [ 50.2%] 60832/121136 windows running_bpb=1.148269 - sliding_eval [ 51.5%] 62432/121136 windows running_bpb=1.149229 - sliding_eval [ 52.9%] 64032/121136 windows running_bpb=1.149413 - sliding_eval [ 54.2%] 65632/121136 windows running_bpb=1.149267 - sliding_eval [ 55.5%] 67232/121136 windows running_bpb=1.148058 - sliding_eval [ 56.8%] 68832/121136 windows running_bpb=1.147801 - sliding_eval [ 58.1%] 70432/121136 windows running_bpb=1.147115 - sliding_eval [ 59.5%] 72032/121136 windows running_bpb=1.147155 - sliding_eval [ 60.8%] 73632/121136 windows running_bpb=1.147124 - sliding_eval [ 62.1%] 75232/121136 windows running_bpb=1.147277 - sliding_eval [ 63.4%] 76832/121136 windows running_bpb=1.146987 - sliding_eval [ 64.7%] 78432/121136 windows running_bpb=1.147571 - sliding_eval [ 66.1%] 80032/121136 windows running_bpb=1.147900 - sliding_eval [ 67.4%] 81632/121136 windows running_bpb=1.147616 - sliding_eval [ 68.7%] 83232/121136 windows running_bpb=1.148655 - sliding_eval [ 70.0%] 84832/121136 windows running_bpb=1.150558 - sliding_eval [ 71.4%] 86432/121136 windows running_bpb=1.149847 - sliding_eval [ 72.7%] 88032/121136 windows running_bpb=1.150566 - sliding_eval [ 74.0%] 89632/121136 windows running_bpb=1.150913 - sliding_eval [ 75.3%] 91232/121136 windows running_bpb=1.150908 - sliding_eval [ 76.6%] 92832/121136 windows running_bpb=1.150504 - sliding_eval [ 78.0%] 94432/121136 windows running_bpb=1.150744 - sliding_eval [ 79.3%] 96032/121136 windows running_bpb=1.150167 - sliding_eval [ 80.6%] 97632/121136 windows running_bpb=1.152978 - sliding_eval [ 81.9%] 99232/121136 windows running_bpb=1.152969 - sliding_eval [ 83.2%] 100832/121136 windows running_bpb=1.153008 - sliding_eval [ 84.6%] 102432/121136 windows running_bpb=1.152635 - sliding_eval [ 85.9%] 104032/121136 windows running_bpb=1.152163 - sliding_eval [ 87.2%] 105632/121136 windows running_bpb=1.151433 - sliding_eval [ 88.5%] 107232/121136 windows running_bpb=1.151416 - sliding_eval [ 89.8%] 108832/121136 windows running_bpb=1.152036 - sliding_eval [ 91.2%] 110432/121136 windows running_bpb=1.152069 - sliding_eval [ 92.5%] 112032/121136 windows running_bpb=1.152041 - sliding_eval [ 93.8%] 113632/121136 windows running_bpb=1.152486 - sliding_eval [ 95.1%] 115232/121136 windows running_bpb=1.152226 - sliding_eval [ 96.4%] 116832/121136 windows running_bpb=1.151816 - sliding_eval [ 97.8%] 118432/121136 windows running_bpb=1.152134 - sliding_eval [ 99.1%] 120032/121136 windows running_bpb=1.152203 -final_int8_zlib_roundtrip val_loss:1.9349 val_bpb:1.1460 eval_time:155204ms -final_int8_zlib_roundtrip_exact val_loss:1.93492097 val_bpb:1.14597222 diff --git a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed42.log b/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed42.log deleted file mode 100644 index 0b1d423e1b..0000000000 --- a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed42.log +++ /dev/null @@ -1,217 +0,0 @@ -logs/repro_seed42.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:22368841 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9299 val_bpb:4.1043 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9310 train_time:141ms step_avg:141.09ms -step:2/20000 train_loss:8.0394 train_time:192ms step_avg:96.13ms -step:3/20000 train_loss:7.5866 train_time:271ms step_avg:90.31ms -step:4/20000 train_loss:6.9692 train_time:350ms step_avg:87.56ms -step:5/20000 train_loss:6.7855 train_time:430ms step_avg:85.92ms -step:6/20000 train_loss:6.7112 train_time:509ms step_avg:84.87ms -step:7/20000 train_loss:6.5691 train_time:589ms step_avg:84.08ms -step:8/20000 train_loss:6.4878 train_time:668ms step_avg:83.54ms -step:9/20000 train_loss:6.2613 train_time:749ms step_avg:83.17ms -step:10/20000 train_loss:6.0139 train_time:828ms step_avg:82.78ms -step:100/20000 train_loss:3.2185 train_time:8047ms step_avg:80.47ms -step:200/20000 train_loss:2.4217 train_time:16138ms step_avg:80.69ms -step:300/20000 train_loss:2.5674 train_time:24230ms step_avg:80.77ms -step:400/20000 train_loss:2.4369 train_time:32327ms step_avg:80.82ms -step:500/20000 train_loss:2.4143 train_time:40357ms step_avg:80.71ms -step:500/20000 val_loss:2.3763 val_bpb:1.4074 train_time:40389ms step_avg:80.78ms -step:600/20000 train_loss:2.3498 train_time:48470ms step_avg:80.78ms -step:700/20000 train_loss:2.3629 train_time:56586ms step_avg:80.84ms -step:800/20000 train_loss:2.2553 train_time:64717ms step_avg:80.90ms -step:900/20000 train_loss:2.1451 train_time:72831ms step_avg:80.92ms -step:1000/20000 train_loss:2.2893 train_time:80896ms step_avg:80.90ms -step:1000/20000 val_loss:2.2408 val_bpb:1.3271 train_time:80927ms step_avg:80.93ms -step:1100/20000 train_loss:2.3339 train_time:89028ms step_avg:80.93ms -step:1200/20000 train_loss:2.3653 train_time:97141ms step_avg:80.95ms -step:1300/20000 train_loss:2.1113 train_time:105265ms step_avg:80.97ms -step:1400/20000 train_loss:2.1968 train_time:113412ms step_avg:81.01ms -step:1500/20000 train_loss:2.2328 train_time:121473ms step_avg:80.98ms -step:1500/20000 val_loss:2.1971 val_bpb:1.3012 train_time:121505ms step_avg:81.00ms -step:1600/20000 train_loss:2.0913 train_time:129610ms step_avg:81.01ms -step:1700/20000 train_loss:2.1588 train_time:137750ms step_avg:81.03ms -step:1800/20000 train_loss:2.1768 train_time:145887ms step_avg:81.05ms -step:1900/20000 train_loss:2.1430 train_time:153960ms step_avg:81.03ms -step:2000/20000 train_loss:2.0824 train_time:162084ms step_avg:81.04ms -step:2000/20000 val_loss:2.1445 val_bpb:1.2701 train_time:162115ms step_avg:81.06ms -step:2100/20000 train_loss:2.0592 train_time:170212ms step_avg:81.05ms -step:2200/20000 train_loss:2.1532 train_time:178342ms step_avg:81.06ms -step:2300/20000 train_loss:2.1209 train_time:186459ms step_avg:81.07ms -step:2400/20000 train_loss:2.0806 train_time:194529ms step_avg:81.05ms -step:2500/20000 train_loss:2.1814 train_time:202658ms step_avg:81.06ms -step:2500/20000 val_loss:2.1180 val_bpb:1.2544 train_time:202687ms step_avg:81.07ms -step:2600/20000 train_loss:2.1218 train_time:210795ms step_avg:81.08ms -step:2700/20000 train_loss:2.1135 train_time:218932ms step_avg:81.09ms -step:2800/20000 train_loss:2.1658 train_time:227072ms step_avg:81.10ms -step:2900/20000 train_loss:2.0330 train_time:235138ms step_avg:81.08ms -step:3000/20000 train_loss:2.1732 train_time:243268ms step_avg:81.09ms -step:3000/20000 val_loss:2.1037 val_bpb:1.2459 train_time:243297ms step_avg:81.10ms -step:3100/20000 train_loss:2.0492 train_time:251415ms step_avg:81.10ms -step:3200/20000 train_loss:2.1852 train_time:259549ms step_avg:81.11ms -step:3300/20000 train_loss:2.0811 train_time:267631ms step_avg:81.10ms -step:3400/20000 train_loss:2.0312 train_time:275775ms step_avg:81.11ms -step:3500/20000 train_loss:2.1952 train_time:283914ms step_avg:81.12ms -step:3500/20000 val_loss:2.0950 val_bpb:1.2408 train_time:283944ms step_avg:81.13ms -step:3600/20000 train_loss:2.1113 train_time:292049ms step_avg:81.12ms -step:3700/20000 train_loss:2.1110 train_time:300182ms step_avg:81.13ms -step:3800/20000 train_loss:2.0916 train_time:308247ms step_avg:81.12ms -step:3900/20000 train_loss:2.0976 train_time:316370ms step_avg:81.12ms -step:4000/20000 train_loss:1.9964 train_time:324501ms step_avg:81.13ms -step:4000/20000 val_loss:2.0891 val_bpb:1.2373 train_time:324533ms step_avg:81.13ms -step:4100/20000 train_loss:2.0379 train_time:332622ms step_avg:81.13ms -step:4200/20000 train_loss:2.1806 train_time:340747ms step_avg:81.13ms -step:4300/20000 train_loss:2.0848 train_time:348819ms step_avg:81.12ms -step:4400/20000 train_loss:2.0663 train_time:356948ms step_avg:81.12ms -step:4500/20000 train_loss:2.1568 train_time:365076ms step_avg:81.13ms -step:4500/20000 val_loss:2.0785 val_bpb:1.2310 train_time:365107ms step_avg:81.13ms -step:4600/20000 train_loss:1.8748 train_time:373196ms step_avg:81.13ms -step:4700/20000 train_loss:2.2622 train_time:381255ms step_avg:81.12ms -step:4800/20000 train_loss:2.4646 train_time:389386ms step_avg:81.12ms -step:4900/20000 train_loss:2.0864 train_time:397537ms step_avg:81.13ms -step:5000/20000 train_loss:2.1410 train_time:405671ms step_avg:81.13ms -step:5000/20000 val_loss:2.0593 val_bpb:1.2196 train_time:405701ms step_avg:81.14ms -step:5100/20000 train_loss:2.1628 train_time:413803ms step_avg:81.14ms -step:5200/20000 train_loss:2.0759 train_time:421867ms step_avg:81.13ms -step:5300/20000 train_loss:2.0441 train_time:430005ms step_avg:81.13ms -step:5400/20000 train_loss:2.0850 train_time:438138ms step_avg:81.14ms -step:5500/20000 train_loss:2.0548 train_time:446264ms step_avg:81.14ms -step:5500/20000 val_loss:2.0408 val_bpb:1.2087 train_time:446294ms step_avg:81.14ms -step:5600/20000 train_loss:1.9948 train_time:454407ms step_avg:81.14ms -step:5700/20000 train_loss:2.0521 train_time:462470ms step_avg:81.14ms -step:5800/20000 train_loss:2.0472 train_time:470606ms step_avg:81.14ms -swa:start step:5900 -step:5900/20000 train_loss:1.9475 train_time:478734ms step_avg:81.14ms -step:6000/20000 train_loss:1.9827 train_time:486956ms step_avg:81.16ms -step:6000/20000 val_loss:2.0230 val_bpb:1.1981 train_time:487010ms step_avg:81.17ms -step:6100/20000 train_loss:1.9602 train_time:495057ms step_avg:81.16ms -step:6200/20000 train_loss:1.9951 train_time:503235ms step_avg:81.17ms -step:6300/20000 train_loss:1.9914 train_time:511399ms step_avg:81.17ms -step:6400/20000 train_loss:2.0448 train_time:519600ms step_avg:81.19ms -step:6500/20000 train_loss:2.1264 train_time:527774ms step_avg:81.20ms -step:6500/20000 val_loss:1.9999 val_bpb:1.1845 train_time:527827ms step_avg:81.20ms -step:6600/20000 train_loss:1.8904 train_time:535881ms step_avg:81.19ms -step:6700/20000 train_loss:1.9893 train_time:544058ms step_avg:81.20ms -step:6800/20000 train_loss:2.0702 train_time:552251ms step_avg:81.21ms -step:6900/20000 train_loss:1.8737 train_time:560415ms step_avg:81.22ms -step:7000/20000 train_loss:1.8403 train_time:568606ms step_avg:81.23ms -step:7000/20000 val_loss:1.9777 val_bpb:1.1713 train_time:568677ms step_avg:81.24ms -step:7100/20000 train_loss:1.9747 train_time:576746ms step_avg:81.23ms -step:7200/20000 train_loss:1.9284 train_time:584938ms step_avg:81.24ms -step:7300/20000 train_loss:2.0473 train_time:593105ms step_avg:81.25ms -step:7385/20000 val_loss:1.9624 val_bpb:1.1622 train_time:600099ms step_avg:81.26ms -stopping_early: wallclock_cap train_time:600099ms step:7385/20000 -peak memory allocated: 16962 MiB reserved: 17072 MiB -swa:applying averaged 30 checkpoints -Serialized model: 87413467 bytes -Code size: 52243 bytes -Total submission size: 87465710 bytes -Serialized model int6+zstd: 15865061 bytes -Total submission size int8+zlib: 15917304 bytes -final_eval_mode:sliding_window stride:64 batch_seqs:32 - sliding_eval [ 0.0%] 32/121136 windows running_bpb=1.217812 - sliding_eval [ 1.3%] 1632/121136 windows running_bpb=1.142244 - sliding_eval [ 2.7%] 3232/121136 windows running_bpb=1.143869 - sliding_eval [ 4.0%] 4832/121136 windows running_bpb=1.137510 - sliding_eval [ 5.3%] 6432/121136 windows running_bpb=1.149324 - sliding_eval [ 6.6%] 8032/121136 windows running_bpb=1.150680 - sliding_eval [ 8.0%] 9632/121136 windows running_bpb=1.152112 - sliding_eval [ 9.3%] 11232/121136 windows running_bpb=1.147263 - sliding_eval [ 10.6%] 12832/121136 windows running_bpb=1.144500 - sliding_eval [ 11.9%] 14432/121136 windows running_bpb=1.146151 - sliding_eval [ 13.2%] 16032/121136 windows running_bpb=1.154878 - sliding_eval [ 14.6%] 17632/121136 windows running_bpb=1.153234 - sliding_eval [ 15.9%] 19232/121136 windows running_bpb=1.154387 - sliding_eval [ 17.2%] 20832/121136 windows running_bpb=1.152614 - sliding_eval [ 18.5%] 22432/121136 windows running_bpb=1.151132 - sliding_eval [ 19.8%] 24032/121136 windows running_bpb=1.151450 - sliding_eval [ 21.2%] 25632/121136 windows running_bpb=1.152771 - sliding_eval [ 22.5%] 27232/121136 windows running_bpb=1.153223 - sliding_eval [ 23.8%] 28832/121136 windows running_bpb=1.159424 - sliding_eval [ 25.1%] 30432/121136 windows running_bpb=1.156860 - sliding_eval [ 26.4%] 32032/121136 windows running_bpb=1.157801 - sliding_eval [ 27.8%] 33632/121136 windows running_bpb=1.156445 - sliding_eval [ 29.1%] 35232/121136 windows running_bpb=1.155906 - sliding_eval [ 30.4%] 36832/121136 windows running_bpb=1.155529 - sliding_eval [ 31.7%] 38432/121136 windows running_bpb=1.156210 - sliding_eval [ 33.0%] 40032/121136 windows running_bpb=1.153835 - sliding_eval [ 34.4%] 41632/121136 windows running_bpb=1.152829 - sliding_eval [ 35.7%] 43232/121136 windows running_bpb=1.153183 - sliding_eval [ 37.0%] 44832/121136 windows running_bpb=1.151985 - sliding_eval [ 38.3%] 46432/121136 windows running_bpb=1.151803 - sliding_eval [ 39.7%] 48032/121136 windows running_bpb=1.151061 - sliding_eval [ 41.0%] 49632/121136 windows running_bpb=1.152265 - sliding_eval [ 42.3%] 51232/121136 windows running_bpb=1.153299 - sliding_eval [ 43.6%] 52832/121136 windows running_bpb=1.153825 - sliding_eval [ 44.9%] 54432/121136 windows running_bpb=1.153325 - sliding_eval [ 46.3%] 56032/121136 windows running_bpb=1.153701 - sliding_eval [ 47.6%] 57632/121136 windows running_bpb=1.152846 - sliding_eval [ 48.9%] 59232/121136 windows running_bpb=1.148994 - sliding_eval [ 50.2%] 60832/121136 windows running_bpb=1.149131 - sliding_eval [ 51.5%] 62432/121136 windows running_bpb=1.150059 - sliding_eval [ 52.9%] 64032/121136 windows running_bpb=1.150232 - sliding_eval [ 54.2%] 65632/121136 windows running_bpb=1.150070 - sliding_eval [ 55.5%] 67232/121136 windows running_bpb=1.148854 - sliding_eval [ 56.8%] 68832/121136 windows running_bpb=1.148560 - sliding_eval [ 58.1%] 70432/121136 windows running_bpb=1.147909 - sliding_eval [ 59.5%] 72032/121136 windows running_bpb=1.148036 - sliding_eval [ 60.8%] 73632/121136 windows running_bpb=1.148014 - sliding_eval [ 62.1%] 75232/121136 windows running_bpb=1.148155 - sliding_eval [ 63.4%] 76832/121136 windows running_bpb=1.147836 - sliding_eval [ 64.7%] 78432/121136 windows running_bpb=1.148439 - sliding_eval [ 66.1%] 80032/121136 windows running_bpb=1.148726 - sliding_eval [ 67.4%] 81632/121136 windows running_bpb=1.148454 - sliding_eval [ 68.7%] 83232/121136 windows running_bpb=1.149482 - sliding_eval [ 70.0%] 84832/121136 windows running_bpb=1.151398 - sliding_eval [ 71.4%] 86432/121136 windows running_bpb=1.150673 - sliding_eval [ 72.7%] 88032/121136 windows running_bpb=1.151403 - sliding_eval [ 74.0%] 89632/121136 windows running_bpb=1.151744 - sliding_eval [ 75.3%] 91232/121136 windows running_bpb=1.151730 - sliding_eval [ 76.6%] 92832/121136 windows running_bpb=1.151305 - sliding_eval [ 78.0%] 94432/121136 windows running_bpb=1.151514 - sliding_eval [ 79.3%] 96032/121136 windows running_bpb=1.150913 - sliding_eval [ 80.6%] 97632/121136 windows running_bpb=1.153700 - sliding_eval [ 81.9%] 99232/121136 windows running_bpb=1.153718 - sliding_eval [ 83.2%] 100832/121136 windows running_bpb=1.153765 - sliding_eval [ 84.6%] 102432/121136 windows running_bpb=1.153384 - sliding_eval [ 85.9%] 104032/121136 windows running_bpb=1.152914 - sliding_eval [ 87.2%] 105632/121136 windows running_bpb=1.152196 - sliding_eval [ 88.5%] 107232/121136 windows running_bpb=1.152180 - sliding_eval [ 89.8%] 108832/121136 windows running_bpb=1.152788 - sliding_eval [ 91.2%] 110432/121136 windows running_bpb=1.152807 - sliding_eval [ 92.5%] 112032/121136 windows running_bpb=1.152799 - sliding_eval [ 93.8%] 113632/121136 windows running_bpb=1.153251 - sliding_eval [ 95.1%] 115232/121136 windows running_bpb=1.153003 - sliding_eval [ 96.4%] 116832/121136 windows running_bpb=1.152614 - sliding_eval [ 97.8%] 118432/121136 windows running_bpb=1.152912 - sliding_eval [ 99.1%] 120032/121136 windows running_bpb=1.153003 -final_int8_zlib_roundtrip val_loss:1.9359 val_bpb:1.1466 eval_time:154937ms -final_int8_zlib_roundtrip_exact val_loss:1.93591485 val_bpb:1.14656085 diff --git a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed7.log b/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed7.log deleted file mode 100644 index cea46043a1..0000000000 --- a/records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/train_seed7.log +++ /dev/null @@ -1,217 +0,0 @@ -logs/repro_seed7.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:22368841 -world_size:8 grad_accum_steps:1 -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:7 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9314 val_bpb:4.1051 train_time:0ms step_avg:0.02ms -step:1/20000 train_loss:6.9327 train_time:139ms step_avg:138.73ms -step:2/20000 train_loss:8.0944 train_time:189ms step_avg:94.62ms -step:3/20000 train_loss:7.6485 train_time:268ms step_avg:89.29ms -step:4/20000 train_loss:6.9720 train_time:348ms step_avg:87.04ms -step:5/20000 train_loss:6.8367 train_time:427ms step_avg:85.34ms -step:6/20000 train_loss:6.7016 train_time:506ms step_avg:84.36ms -step:7/20000 train_loss:6.5065 train_time:585ms step_avg:83.61ms -step:8/20000 train_loss:6.5046 train_time:665ms step_avg:83.08ms -step:9/20000 train_loss:6.3458 train_time:745ms step_avg:82.75ms -step:10/20000 train_loss:6.0822 train_time:824ms step_avg:82.41ms -step:100/20000 train_loss:3.2092 train_time:8024ms step_avg:80.24ms -step:200/20000 train_loss:2.4059 train_time:16105ms step_avg:80.53ms -step:300/20000 train_loss:2.5671 train_time:24186ms step_avg:80.62ms -step:400/20000 train_loss:2.4327 train_time:32257ms step_avg:80.64ms -step:500/20000 train_loss:2.4160 train_time:40268ms step_avg:80.54ms -step:500/20000 val_loss:2.3756 val_bpb:1.4069 train_time:40298ms step_avg:80.60ms -step:600/20000 train_loss:2.3475 train_time:48357ms step_avg:80.60ms -step:700/20000 train_loss:2.3648 train_time:56470ms step_avg:80.67ms -step:800/20000 train_loss:2.2575 train_time:64595ms step_avg:80.74ms -step:900/20000 train_loss:2.1461 train_time:72694ms step_avg:80.77ms -step:1000/20000 train_loss:2.2902 train_time:80736ms step_avg:80.74ms -step:1000/20000 val_loss:2.2431 val_bpb:1.3285 train_time:80767ms step_avg:80.77ms -step:1100/20000 train_loss:2.3372 train_time:88855ms step_avg:80.78ms -step:1200/20000 train_loss:2.3680 train_time:96961ms step_avg:80.80ms -step:1300/20000 train_loss:2.1164 train_time:105088ms step_avg:80.84ms -step:1400/20000 train_loss:2.1986 train_time:113201ms step_avg:80.86ms -step:1500/20000 train_loss:2.2387 train_time:121247ms step_avg:80.83ms -step:1500/20000 val_loss:2.1991 val_bpb:1.3024 train_time:121278ms step_avg:80.85ms -step:1600/20000 train_loss:2.0927 train_time:129348ms step_avg:80.84ms -step:1700/20000 train_loss:2.1590 train_time:137469ms step_avg:80.86ms -step:1800/20000 train_loss:2.1712 train_time:145593ms step_avg:80.89ms -step:1900/20000 train_loss:2.1385 train_time:153651ms step_avg:80.87ms -step:2000/20000 train_loss:2.0794 train_time:161771ms step_avg:80.89ms -step:2000/20000 val_loss:2.1462 val_bpb:1.2711 train_time:161800ms step_avg:80.90ms -step:2100/20000 train_loss:2.0627 train_time:169910ms step_avg:80.91ms -step:2200/20000 train_loss:2.1523 train_time:178030ms step_avg:80.92ms -step:2300/20000 train_loss:2.1207 train_time:186161ms step_avg:80.94ms -step:2400/20000 train_loss:2.0789 train_time:194213ms step_avg:80.92ms -step:2500/20000 train_loss:2.1794 train_time:202352ms step_avg:80.94ms -step:2500/20000 val_loss:2.1183 val_bpb:1.2546 train_time:202382ms step_avg:80.95ms -step:2600/20000 train_loss:2.1236 train_time:210483ms step_avg:80.96ms -step:2700/20000 train_loss:2.1158 train_time:218597ms step_avg:80.96ms -step:2800/20000 train_loss:2.1667 train_time:226733ms step_avg:80.98ms -step:2900/20000 train_loss:2.0343 train_time:234798ms step_avg:80.96ms -step:3000/20000 train_loss:2.1712 train_time:242909ms step_avg:80.97ms -step:3000/20000 val_loss:2.1029 val_bpb:1.2455 train_time:242938ms step_avg:80.98ms -step:3100/20000 train_loss:2.0485 train_time:251038ms step_avg:80.98ms -step:3200/20000 train_loss:2.1859 train_time:259153ms step_avg:80.99ms -step:3300/20000 train_loss:2.0814 train_time:267203ms step_avg:80.97ms -step:3400/20000 train_loss:2.0306 train_time:275332ms step_avg:80.98ms -step:3500/20000 train_loss:2.1933 train_time:283448ms step_avg:80.99ms -step:3500/20000 val_loss:2.0936 val_bpb:1.2399 train_time:283478ms step_avg:80.99ms -step:3600/20000 train_loss:2.1121 train_time:291568ms step_avg:80.99ms -step:3700/20000 train_loss:2.1086 train_time:299670ms step_avg:80.99ms -step:3800/20000 train_loss:2.0914 train_time:307736ms step_avg:80.98ms -step:3900/20000 train_loss:2.0955 train_time:315873ms step_avg:80.99ms -step:4000/20000 train_loss:1.9926 train_time:323983ms step_avg:81.00ms -step:4000/20000 val_loss:2.0874 val_bpb:1.2363 train_time:324012ms step_avg:81.00ms -step:4100/20000 train_loss:2.0390 train_time:332110ms step_avg:81.00ms -step:4200/20000 train_loss:2.1809 train_time:340221ms step_avg:81.01ms -step:4300/20000 train_loss:2.0871 train_time:348269ms step_avg:80.99ms -step:4400/20000 train_loss:2.0686 train_time:356381ms step_avg:81.00ms -step:4500/20000 train_loss:2.1547 train_time:364500ms step_avg:81.00ms -step:4500/20000 val_loss:2.0773 val_bpb:1.2303 train_time:364529ms step_avg:81.01ms -step:4600/20000 train_loss:1.8733 train_time:372623ms step_avg:81.01ms -step:4700/20000 train_loss:2.2637 train_time:380674ms step_avg:80.99ms -step:4800/20000 train_loss:2.4601 train_time:388792ms step_avg:81.00ms -step:4900/20000 train_loss:2.0823 train_time:396926ms step_avg:81.01ms -step:5000/20000 train_loss:2.1381 train_time:405050ms step_avg:81.01ms -step:5000/20000 val_loss:2.0581 val_bpb:1.2189 train_time:405079ms step_avg:81.02ms -step:5100/20000 train_loss:2.1576 train_time:413163ms step_avg:81.01ms -step:5200/20000 train_loss:2.0770 train_time:421218ms step_avg:81.00ms -step:5300/20000 train_loss:2.0447 train_time:429326ms step_avg:81.00ms -step:5400/20000 train_loss:2.0824 train_time:437472ms step_avg:81.01ms -step:5500/20000 train_loss:2.0543 train_time:445595ms step_avg:81.02ms -step:5500/20000 val_loss:2.0392 val_bpb:1.2078 train_time:445625ms step_avg:81.02ms -step:5600/20000 train_loss:1.9955 train_time:453729ms step_avg:81.02ms -step:5700/20000 train_loss:2.0511 train_time:461794ms step_avg:81.02ms -step:5800/20000 train_loss:2.0449 train_time:469913ms step_avg:81.02ms -step:5900/20000 train_loss:1.9491 train_time:478054ms step_avg:81.03ms -swa:start step:5950 -step:6000/20000 train_loss:1.9804 train_time:486249ms step_avg:81.04ms -step:6000/20000 val_loss:2.0216 val_bpb:1.1973 train_time:486308ms step_avg:81.05ms -step:6100/20000 train_loss:1.9605 train_time:494353ms step_avg:81.04ms -step:6200/20000 train_loss:1.9964 train_time:502535ms step_avg:81.05ms -step:6300/20000 train_loss:1.9883 train_time:510749ms step_avg:81.07ms -step:6400/20000 train_loss:2.0438 train_time:518937ms step_avg:81.08ms -step:6500/20000 train_loss:2.1242 train_time:527090ms step_avg:81.09ms -step:6500/20000 val_loss:1.9985 val_bpb:1.1836 train_time:527147ms step_avg:81.10ms -step:6600/20000 train_loss:1.8904 train_time:535211ms step_avg:81.09ms -step:6700/20000 train_loss:1.9885 train_time:543382ms step_avg:81.10ms -step:6800/20000 train_loss:2.0730 train_time:551575ms step_avg:81.11ms -step:6900/20000 train_loss:1.8732 train_time:559768ms step_avg:81.13ms -step:7000/20000 train_loss:1.8375 train_time:567934ms step_avg:81.13ms -step:7000/20000 val_loss:1.9763 val_bpb:1.1705 train_time:567990ms step_avg:81.14ms -step:7100/20000 train_loss:1.9738 train_time:576059ms step_avg:81.14ms -step:7200/20000 train_loss:1.9234 train_time:584245ms step_avg:81.15ms -step:7300/20000 train_loss:2.0473 train_time:592436ms step_avg:81.16ms -step:7393/20000 val_loss:1.9605 val_bpb:1.1611 train_time:600101ms step_avg:81.17ms -stopping_early: wallclock_cap train_time:600101ms step:7393/20000 -peak memory allocated: 16962 MiB reserved: 17072 MiB -swa:applying averaged 29 checkpoints -Serialized model: 87413467 bytes -Code size: 52243 bytes -Total submission size: 87465710 bytes -Serialized model int6+zstd: 15892664 bytes -Total submission size int8+zlib: 15944907 bytes -final_eval_mode:sliding_window stride:64 batch_seqs:32 - sliding_eval [ 0.0%] 32/121136 windows running_bpb=1.204187 - sliding_eval [ 1.3%] 1632/121136 windows running_bpb=1.140479 - sliding_eval [ 2.7%] 3232/121136 windows running_bpb=1.141456 - sliding_eval [ 4.0%] 4832/121136 windows running_bpb=1.134566 - sliding_eval [ 5.3%] 6432/121136 windows running_bpb=1.146982 - sliding_eval [ 6.6%] 8032/121136 windows running_bpb=1.148125 - sliding_eval [ 8.0%] 9632/121136 windows running_bpb=1.149646 - sliding_eval [ 9.3%] 11232/121136 windows running_bpb=1.145113 - sliding_eval [ 10.6%] 12832/121136 windows running_bpb=1.142509 - sliding_eval [ 11.9%] 14432/121136 windows running_bpb=1.144101 - sliding_eval [ 13.2%] 16032/121136 windows running_bpb=1.153069 - sliding_eval [ 14.6%] 17632/121136 windows running_bpb=1.151471 - sliding_eval [ 15.9%] 19232/121136 windows running_bpb=1.152853 - sliding_eval [ 17.2%] 20832/121136 windows running_bpb=1.151038 - sliding_eval [ 18.5%] 22432/121136 windows running_bpb=1.149419 - sliding_eval [ 19.8%] 24032/121136 windows running_bpb=1.149769 - sliding_eval [ 21.2%] 25632/121136 windows running_bpb=1.151145 - sliding_eval [ 22.5%] 27232/121136 windows running_bpb=1.151649 - sliding_eval [ 23.8%] 28832/121136 windows running_bpb=1.157772 - sliding_eval [ 25.1%] 30432/121136 windows running_bpb=1.155120 - sliding_eval [ 26.4%] 32032/121136 windows running_bpb=1.156148 - sliding_eval [ 27.8%] 33632/121136 windows running_bpb=1.154779 - sliding_eval [ 29.1%] 35232/121136 windows running_bpb=1.154098 - sliding_eval [ 30.4%] 36832/121136 windows running_bpb=1.153707 - sliding_eval [ 31.7%] 38432/121136 windows running_bpb=1.154262 - sliding_eval [ 33.0%] 40032/121136 windows running_bpb=1.151908 - sliding_eval [ 34.4%] 41632/121136 windows running_bpb=1.150870 - sliding_eval [ 35.7%] 43232/121136 windows running_bpb=1.151243 - sliding_eval [ 37.0%] 44832/121136 windows running_bpb=1.150119 - sliding_eval [ 38.3%] 46432/121136 windows running_bpb=1.150033 - sliding_eval [ 39.7%] 48032/121136 windows running_bpb=1.149298 - sliding_eval [ 41.0%] 49632/121136 windows running_bpb=1.150513 - sliding_eval [ 42.3%] 51232/121136 windows running_bpb=1.151542 - sliding_eval [ 43.6%] 52832/121136 windows running_bpb=1.152059 - sliding_eval [ 44.9%] 54432/121136 windows running_bpb=1.151549 - sliding_eval [ 46.3%] 56032/121136 windows running_bpb=1.151954 - sliding_eval [ 47.6%] 57632/121136 windows running_bpb=1.151101 - sliding_eval [ 48.9%] 59232/121136 windows running_bpb=1.147247 - sliding_eval [ 50.2%] 60832/121136 windows running_bpb=1.147351 - sliding_eval [ 51.5%] 62432/121136 windows running_bpb=1.148300 - sliding_eval [ 52.9%] 64032/121136 windows running_bpb=1.148482 - sliding_eval [ 54.2%] 65632/121136 windows running_bpb=1.148331 - sliding_eval [ 55.5%] 67232/121136 windows running_bpb=1.147109 - sliding_eval [ 56.8%] 68832/121136 windows running_bpb=1.146807 - sliding_eval [ 58.1%] 70432/121136 windows running_bpb=1.146093 - sliding_eval [ 59.5%] 72032/121136 windows running_bpb=1.146142 - sliding_eval [ 60.8%] 73632/121136 windows running_bpb=1.146091 - sliding_eval [ 62.1%] 75232/121136 windows running_bpb=1.146251 - sliding_eval [ 63.4%] 76832/121136 windows running_bpb=1.145998 - sliding_eval [ 64.7%] 78432/121136 windows running_bpb=1.146598 - sliding_eval [ 66.1%] 80032/121136 windows running_bpb=1.146887 - sliding_eval [ 67.4%] 81632/121136 windows running_bpb=1.146625 - sliding_eval [ 68.7%] 83232/121136 windows running_bpb=1.147686 - sliding_eval [ 70.0%] 84832/121136 windows running_bpb=1.149608 - sliding_eval [ 71.4%] 86432/121136 windows running_bpb=1.148926 - sliding_eval [ 72.7%] 88032/121136 windows running_bpb=1.149671 - sliding_eval [ 74.0%] 89632/121136 windows running_bpb=1.150045 - sliding_eval [ 75.3%] 91232/121136 windows running_bpb=1.150024 - sliding_eval [ 76.6%] 92832/121136 windows running_bpb=1.149607 - sliding_eval [ 78.0%] 94432/121136 windows running_bpb=1.149839 - sliding_eval [ 79.3%] 96032/121136 windows running_bpb=1.149256 - sliding_eval [ 80.6%] 97632/121136 windows running_bpb=1.152058 - sliding_eval [ 81.9%] 99232/121136 windows running_bpb=1.152077 - sliding_eval [ 83.2%] 100832/121136 windows running_bpb=1.152093 - sliding_eval [ 84.6%] 102432/121136 windows running_bpb=1.151727 - sliding_eval [ 85.9%] 104032/121136 windows running_bpb=1.151235 - sliding_eval [ 87.2%] 105632/121136 windows running_bpb=1.150486 - sliding_eval [ 88.5%] 107232/121136 windows running_bpb=1.150457 - sliding_eval [ 89.8%] 108832/121136 windows running_bpb=1.151088 - sliding_eval [ 91.2%] 110432/121136 windows running_bpb=1.151106 - sliding_eval [ 92.5%] 112032/121136 windows running_bpb=1.151078 - sliding_eval [ 93.8%] 113632/121136 windows running_bpb=1.151528 - sliding_eval [ 95.1%] 115232/121136 windows running_bpb=1.151293 - sliding_eval [ 96.4%] 116832/121136 windows running_bpb=1.150913 - sliding_eval [ 97.8%] 118432/121136 windows running_bpb=1.151248 - sliding_eval [ 99.1%] 120032/121136 windows running_bpb=1.151319 -final_int8_zlib_roundtrip val_loss:1.9331 val_bpb:1.1449 eval_time:155292ms -final_int8_zlib_roundtrip_exact val_loss:1.93314046 val_bpb:1.14491770 diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/README.md b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/README.md deleted file mode 100644 index a86bd2eaa4..0000000000 --- a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/README.md +++ /dev/null @@ -1,86 +0,0 @@ -## Record: 11L Partial RoPE + LN Scale + EMA + XSA4 (val_bpb: 1.1248) - -**val_bpb = 1.1248** (sliding window, stride=64) | **15.6 MB** artifact | 8xH100 SXM, 600s - -Previous: [PR #70](https://github.com/openai/parameter-golf/pull/70) (9L, 1.1659) → [PR #164](https://github.com/openai/parameter-golf/pull/164) (9L, 1.1524) → [PR #198](https://github.com/openai/parameter-golf/pull/198) (11L, 1.1318) → [PR #287](https://github.com/openai/parameter-golf/pull/287) (11L, 1.1271) → this - -### Changes from PR #287 - -| | [PR #287](https://github.com/openai/parameter-golf/pull/287) | This | -|---|---|---| -| val_bpb (sliding s64) | 1.1271 | **1.1248** | -| Partial RoPE | None (full 64d) | 16 of 64 dims | -| LN Scale | None | 1/sqrt(layer_idx+1) | -| Artifact | 15.5 MB | 15.6 MB | -| Everything else | Same | Same | - -### What's new - -1. **Partial RoPE (16 of 64 dims)**. Rotary position embeddings applied to only the first 16 of 64 head dimensions (25%). The remaining 48 dims attend without positional bias, allowing the model to learn position-invariant patterns. Zero new parameters. - -2. **LN Scale**. RMSNorm outputs are scaled by 1/sqrt(layer_idx+1), damping deeper layers' contributions. Stabilizes training and improves convergence in deep models. Zero new parameters. - -### Carried from PR #287 - -- 11 transformer layers with U-Net skip connections -- Exclusive Self Attention (XSA) on last 4 layers -- EMA weight averaging (decay=0.997, every step) -- Orthogonal + muP-scaled init on all large matrices -- 3x MLP (hidden=1536), relu² activation -- Int6 mixed quantization + zstd-22 (int6 on MLP+attention, int8 on embeddings) -- Weight decay 0.04 (Muon + AdamW) -- SmearGate (learned token blending gate, ~512 params) -- Bigram Hash Embedding (2048-bucket, dim=128, projected to 512) -- FlashAttention 3 (direct flash_attn_func calls) -- Sequence length 2048 with NTK-aware RoPE -- Muon optimizer, momentum 0.99 with warmup, warmdown 3000 iters, grad clip 0.3 - -### Configuration - -```bash -NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=2048 XSA_LAST_N=4 \ -EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 \ -ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 QAT_THRESHOLD=0.1 \ -MUON_WD=0.04 ADAM_WD=0.04 \ -MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ -MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \ -MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3000 \ -ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -### Key Metrics - -- 7,051 steps in 600s (85ms/step) -- ~5.5B train tokens (7,051 steps x 786,432 tokens/step) -- Peak memory: ~20,600 MiB per GPU - -| Metric | Value | -|--------|-------| -| Pre-quant val_bpb | 1.1418 | -| Int6 roundtrip val_bpb | 1.1485 | -| **Int6 sliding val_bpb (s64)** | **1.1248** | -| Compressed artifact (int6+zstd) | 15,544,691 bytes | -| Code size | 67,617 bytes | -| **Total submission size** | **15,612,308 bytes** | - -### Reproducibility - -| Seed | Steps | Sliding s64 | Artifact | -|------|-------|-------------|----------| -| **2025** | **7,051** | **1.1248** | **15,612,308** | -| 42 | 7,061 | 1.1250 | 15,528,666 | -| 1337 | 7,063 | 1.1253 | 15,639,340 | - -Mean val_bpb: **1.1250**. Submitted: seed 2025 (best). Inter-seed variance: 0.0005. - -### Included files - -- `train_gpt.py` — full training + quantization + evaluation script -- `train.log` — training log from best seed (2025) -- `train_seed2025.log`, `train_seed42.log`, `train_seed1337.log` — all seed logs -- `submission.json` — leaderboard metadata - -### Note on Late QAT - -The submitted code includes a Late QAT flag (`LATE_QAT=1`) intended to enable STE int6 fake-quantization in the final 4% of training. Post-submission analysis (credit: @152334H) revealed that `torch.compile` constant-folds the `CastedLinear._qat_enabled` class attribute at first trace, so the STE branch is dead-code-eliminated and never activates during training. Late QAT had no effect on the results. The score is driven entirely by Partial RoPE and LN Scale. diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/submission.json b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/submission.json deleted file mode 100644 index 0bc2d9f376..0000000000 --- a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/submission.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "author": "Jack Princz", - "github_id": "jfprincz", - "name": "Record: 11L Partial RoPE + LN Scale + EMA + XSA4", - "blurb": "11 layers with Partial RoPE (16 of 64 dims), LN Scale (1/sqrt(l+1)), EMA weight averaging (decay=0.997), Exclusive Self Attention (XSA) on last 4 layers. Int6 per-row on all MLP+attention weights, int8 tok_emb, zstd-22. Weight decay 0.04 (Muon+AdamW). OrthoInit + muP scaling. SmearGate + BigramHash(2048x128). FA3. Sliding window eval stride=64, seq=2048. Note: Late QAT flag is present in the code but inactive due to torch.compile constant-folding.", - "date": "2026-03-21T06:00:00Z", - "val_loss": 1.89924867, - "val_bpb": 1.12484502, - "pre_quant_val_loss": 1.9279, - "pre_quant_val_bpb": 1.1418, - "int6_zstd_val_loss": 1.93912126, - "int6_zstd_val_bpb": 1.14845684, - "bytes_total": 15612308, - "bytes_model_int6_zstd": 15544691, - "bytes_code": 67617 -} diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train.log b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train.log deleted file mode 100644 index 767fbca9d6..0000000000 --- a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train.log +++ /dev/null @@ -1,1738 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - rope_dims = int(os.environ.get("ROPE_DIMS", 0)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) - late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.rope_dims = rope_dims if rope_dims > 0 else dim - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - rd = self.rope_dims - inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - rd = cos.size(-1) * 2 - if rd < x.size(-1): - x_rope, x_pass = x[..., :rd], x[..., rd:] - half = rd // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rot, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = rope_dims - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - layer_idx: int = 0, - ln_scale: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - s = self.ln_scale_factor - attn_out = self.attn(self.attn_norm(x) * s) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - rope_dims=rope_dims, - layer_idx=i, - ln_scale=ln_scale, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) - if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Sat Mar 21 05:30:01 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 42C P0 131W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 43C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 36C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 42C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 39C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 35C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 184004 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 184005 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 184006 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 184007 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 184008 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 184009 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 184010 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 184011 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:2025 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/9000 train_loss:6.9287 train_time:147ms step_avg:146.68ms -step:2/9000 train_loss:8.6093 train_time:238ms step_avg:118.88ms -step:3/9000 train_loss:7.8135 train_time:348ms step_avg:115.93ms -step:4/9000 train_loss:7.2171 train_time:465ms step_avg:116.33ms -step:5/9000 train_loss:7.0157 train_time:583ms step_avg:116.59ms -step:6/9000 train_loss:6.9238 train_time:696ms step_avg:115.95ms -step:7/9000 train_loss:6.8594 train_time:816ms step_avg:116.61ms -step:8/9000 train_loss:6.8059 train_time:936ms step_avg:117.00ms -step:9/9000 train_loss:6.4555 train_time:1050ms step_avg:116.65ms -step:10/9000 train_loss:6.1152 train_time:1165ms step_avg:116.52ms -step:200/9000 train_loss:2.4421 train_time:17238ms step_avg:86.19ms -step:400/9000 train_loss:2.4489 train_time:34347ms step_avg:85.87ms -step:600/9000 train_loss:2.3426 train_time:51245ms step_avg:85.41ms -step:800/9000 train_loss:2.2484 train_time:68253ms step_avg:85.32ms -step:1000/9000 train_loss:2.2816 train_time:85175ms step_avg:85.17ms -step:1200/9000 train_loss:2.3513 train_time:102167ms step_avg:85.14ms -step:1400/9000 train_loss:2.1864 train_time:119175ms step_avg:85.13ms -step:1600/9000 train_loss:2.0718 train_time:136086ms step_avg:85.05ms -step:1800/9000 train_loss:2.1513 train_time:153111ms step_avg:85.06ms -step:2000/9000 train_loss:2.0619 train_time:170065ms step_avg:85.03ms -step:2200/9000 train_loss:2.1277 train_time:187099ms step_avg:85.04ms -step:2400/9000 train_loss:2.0597 train_time:204035ms step_avg:85.01ms -step:2600/9000 train_loss:2.1060 train_time:221047ms step_avg:85.02ms -step:2800/9000 train_loss:2.1497 train_time:238048ms step_avg:85.02ms -step:3000/9000 train_loss:2.1569 train_time:254965ms step_avg:84.99ms -step:3200/9000 train_loss:2.1658 train_time:272023ms step_avg:85.01ms -step:3400/9000 train_loss:2.0182 train_time:288932ms step_avg:84.98ms -step:3600/9000 train_loss:2.0979 train_time:306055ms step_avg:85.02ms -step:3800/9000 train_loss:2.0741 train_time:323059ms step_avg:85.02ms -step:4000/9000 train_loss:1.9803 train_time:340194ms step_avg:85.05ms -step:4200/9000 train_loss:2.1582 train_time:357294ms step_avg:85.07ms -step:4400/9000 train_loss:2.0428 train_time:374435ms step_avg:85.10ms -step:4600/9000 train_loss:1.8465 train_time:391578ms step_avg:85.13ms -step:4800/9000 train_loss:2.4338 train_time:408766ms step_avg:85.16ms -step:5000/9000 train_loss:2.1121 train_time:425810ms step_avg:85.16ms -step:5200/9000 train_loss:2.0455 train_time:442747ms step_avg:85.14ms -step:5400/9000 train_loss:2.0548 train_time:459747ms step_avg:85.14ms -step:5600/9000 train_loss:1.9599 train_time:476731ms step_avg:85.13ms -step:5800/9000 train_loss:2.0013 train_time:493652ms step_avg:85.11ms -step:6000/9000 train_loss:1.9434 train_time:510668ms step_avg:85.11ms -step:6200/9000 train_loss:1.9595 train_time:527590ms step_avg:85.10ms -step:6400/9000 train_loss:2.0020 train_time:544640ms step_avg:85.10ms -step:6600/9000 train_loss:1.8443 train_time:561567ms step_avg:85.09ms -late_qat:enabled step:6751 scale:0.0998 -step:6800/9000 train_loss:2.0259 train_time:578667ms step_avg:85.10ms -step:7000/9000 train_loss:1.7919 train_time:595662ms step_avg:85.09ms -step:7051/9000 val_loss:1.9279 val_bpb:1.1418 train_time:599986ms step_avg:85.09ms -stopping_early: wallclock_cap train_time:599986ms step:7051/9000 -peak memory allocated: 20596 MiB reserved: 20652 MiB -ema:applying EMA weights -Serialized model: 105783807 bytes -Code size: 67617 bytes -Serialized model int6+zstd: 15544691 bytes -Total submission size int6+zstd: 15612308 bytes -final_int6_roundtrip val_loss:1.9391 val_bpb:1.1485 eval_time:5890ms -final_int6_roundtrip_exact val_loss:1.93912126 val_bpb:1.14845684 -final_int6_sliding_window val_loss:1.8992 val_bpb:1.1248 stride:64 eval_time:72659ms -final_int6_sliding_window_exact val_loss:1.89924867 val_bpb:1.12484502 diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_gpt.py b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_gpt.py deleted file mode 100644 index 754eab1f1f..0000000000 --- a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_gpt.py +++ /dev/null @@ -1,1588 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - rope_dims = int(os.environ.get("ROPE_DIMS", 0)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) - late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.rope_dims = rope_dims if rope_dims > 0 else dim - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - rd = self.rope_dims - inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - rd = cos.size(-1) * 2 - if rd < x.size(-1): - x_rope, x_pass = x[..., :rd], x[..., rd:] - half = rd // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rot, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = rope_dims - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - layer_idx: int = 0, - ln_scale: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - s = self.ln_scale_factor - attn_out = self.attn(self.attn_norm(x) * s) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - rope_dims=rope_dims, - layer_idx=i, - ln_scale=ln_scale, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) - if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed1337.log b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed1337.log deleted file mode 100644 index b1cd8c7cd4..0000000000 --- a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed1337.log +++ /dev/null @@ -1,1738 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - rope_dims = int(os.environ.get("ROPE_DIMS", 0)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) - late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.rope_dims = rope_dims if rope_dims > 0 else dim - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - rd = self.rope_dims - inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - rd = cos.size(-1) * 2 - if rd < x.size(-1): - x_rope, x_pass = x[..., :rd], x[..., rd:] - half = rd // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rot, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = rope_dims - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - layer_idx: int = 0, - ln_scale: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - s = self.ln_scale_factor - attn_out = self.attn(self.attn_norm(x) * s) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - rope_dims=rope_dims, - layer_idx=i, - ln_scale=ln_scale, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) - if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Sat Mar 21 05:17:26 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 31C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 33C P0 124W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 30C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 31C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 29C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 29C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 182911 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 182912 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 182913 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 182914 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 182915 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 182916 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 182917 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 182918 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/9000 train_loss:6.9326 train_time:135ms step_avg:135.16ms -step:2/9000 train_loss:8.7164 train_time:211ms step_avg:105.38ms -step:3/9000 train_loss:7.9408 train_time:308ms step_avg:102.62ms -step:4/9000 train_loss:7.2228 train_time:400ms step_avg:99.94ms -step:5/9000 train_loss:6.9711 train_time:495ms step_avg:99.07ms -step:6/9000 train_loss:6.8209 train_time:585ms step_avg:97.55ms -step:7/9000 train_loss:6.7680 train_time:676ms step_avg:96.59ms -step:8/9000 train_loss:6.7224 train_time:783ms step_avg:97.83ms -step:9/9000 train_loss:6.3933 train_time:889ms step_avg:98.77ms -step:10/9000 train_loss:6.0724 train_time:980ms step_avg:98.02ms -step:200/9000 train_loss:2.4445 train_time:17039ms step_avg:85.20ms -step:400/9000 train_loss:2.4500 train_time:34038ms step_avg:85.10ms -step:600/9000 train_loss:2.3526 train_time:50912ms step_avg:84.85ms -step:800/9000 train_loss:2.2510 train_time:67865ms step_avg:84.83ms -step:1000/9000 train_loss:2.2865 train_time:84779ms step_avg:84.78ms -step:1200/9000 train_loss:2.3586 train_time:101770ms step_avg:84.81ms -step:1400/9000 train_loss:2.1866 train_time:118767ms step_avg:84.83ms -step:1600/9000 train_loss:2.0768 train_time:135727ms step_avg:84.83ms -step:1800/9000 train_loss:2.1621 train_time:152699ms step_avg:84.83ms -step:2000/9000 train_loss:2.0636 train_time:169606ms step_avg:84.80ms -step:2200/9000 train_loss:2.1335 train_time:186602ms step_avg:84.82ms -step:2400/9000 train_loss:2.0642 train_time:203517ms step_avg:84.80ms -step:2600/9000 train_loss:2.1083 train_time:220530ms step_avg:84.82ms -step:2800/9000 train_loss:2.1540 train_time:237515ms step_avg:84.83ms -step:3000/9000 train_loss:2.1609 train_time:254412ms step_avg:84.80ms -step:3200/9000 train_loss:2.1718 train_time:271409ms step_avg:84.82ms -step:3400/9000 train_loss:2.0208 train_time:288336ms step_avg:84.80ms -step:3600/9000 train_loss:2.0990 train_time:305419ms step_avg:84.84ms -step:3800/9000 train_loss:2.0744 train_time:322377ms step_avg:84.84ms -step:4000/9000 train_loss:1.9836 train_time:339605ms step_avg:84.90ms -step:4200/9000 train_loss:2.1648 train_time:356672ms step_avg:84.92ms -step:4400/9000 train_loss:2.0403 train_time:373641ms step_avg:84.92ms -step:4600/9000 train_loss:1.8473 train_time:390750ms step_avg:84.95ms -step:4800/9000 train_loss:2.4305 train_time:407747ms step_avg:84.95ms -step:5000/9000 train_loss:2.1102 train_time:424750ms step_avg:84.95ms -step:5200/9000 train_loss:2.0490 train_time:441670ms step_avg:84.94ms -step:5400/9000 train_loss:2.0527 train_time:458669ms step_avg:84.94ms -step:5600/9000 train_loss:1.9608 train_time:475703ms step_avg:84.95ms -step:5800/9000 train_loss:2.0024 train_time:492612ms step_avg:84.93ms -step:6000/9000 train_loss:1.9437 train_time:509621ms step_avg:84.94ms -step:6200/9000 train_loss:1.9595 train_time:526574ms step_avg:84.93ms -step:6400/9000 train_loss:2.0013 train_time:543613ms step_avg:84.94ms -step:6600/9000 train_loss:1.8487 train_time:560569ms step_avg:84.93ms -late_qat:enabled step:6764 scale:0.0997 -step:6800/9000 train_loss:2.0261 train_time:577613ms step_avg:84.94ms -step:7000/9000 train_loss:1.7927 train_time:594628ms step_avg:84.95ms -step:7063/9000 val_loss:1.9284 val_bpb:1.1421 train_time:600099ms step_avg:84.96ms -stopping_early: wallclock_cap train_time:600099ms step:7063/9000 -peak memory allocated: 20596 MiB reserved: 20652 MiB -ema:applying EMA weights -Serialized model: 105783807 bytes -Code size: 67617 bytes -Serialized model int6+zstd: 15571723 bytes -Total submission size int6+zstd: 15639340 bytes -final_int6_roundtrip val_loss:1.9395 val_bpb:1.1487 eval_time:5857ms -final_int6_roundtrip_exact val_loss:1.93952169 val_bpb:1.14869399 -final_int6_sliding_window val_loss:1.9000 val_bpb:1.1253 stride:64 eval_time:72610ms -final_int6_sliding_window_exact val_loss:1.89996173 val_bpb:1.12526734 diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed2025.log b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed2025.log deleted file mode 100644 index 767fbca9d6..0000000000 --- a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed2025.log +++ /dev/null @@ -1,1738 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - rope_dims = int(os.environ.get("ROPE_DIMS", 0)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) - late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.rope_dims = rope_dims if rope_dims > 0 else dim - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - rd = self.rope_dims - inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - rd = cos.size(-1) * 2 - if rd < x.size(-1): - x_rope, x_pass = x[..., :rd], x[..., rd:] - half = rd // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rot, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = rope_dims - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - layer_idx: int = 0, - ln_scale: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - s = self.ln_scale_factor - attn_out = self.attn(self.attn_norm(x) * s) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - rope_dims=rope_dims, - layer_idx=i, - ln_scale=ln_scale, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) - if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Sat Mar 21 05:30:01 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 42C P0 131W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 43C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 36C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 42C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 39C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 35C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 184004 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 184005 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 184006 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 184007 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 184008 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 184009 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 184010 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 184011 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:2025 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/9000 train_loss:6.9287 train_time:147ms step_avg:146.68ms -step:2/9000 train_loss:8.6093 train_time:238ms step_avg:118.88ms -step:3/9000 train_loss:7.8135 train_time:348ms step_avg:115.93ms -step:4/9000 train_loss:7.2171 train_time:465ms step_avg:116.33ms -step:5/9000 train_loss:7.0157 train_time:583ms step_avg:116.59ms -step:6/9000 train_loss:6.9238 train_time:696ms step_avg:115.95ms -step:7/9000 train_loss:6.8594 train_time:816ms step_avg:116.61ms -step:8/9000 train_loss:6.8059 train_time:936ms step_avg:117.00ms -step:9/9000 train_loss:6.4555 train_time:1050ms step_avg:116.65ms -step:10/9000 train_loss:6.1152 train_time:1165ms step_avg:116.52ms -step:200/9000 train_loss:2.4421 train_time:17238ms step_avg:86.19ms -step:400/9000 train_loss:2.4489 train_time:34347ms step_avg:85.87ms -step:600/9000 train_loss:2.3426 train_time:51245ms step_avg:85.41ms -step:800/9000 train_loss:2.2484 train_time:68253ms step_avg:85.32ms -step:1000/9000 train_loss:2.2816 train_time:85175ms step_avg:85.17ms -step:1200/9000 train_loss:2.3513 train_time:102167ms step_avg:85.14ms -step:1400/9000 train_loss:2.1864 train_time:119175ms step_avg:85.13ms -step:1600/9000 train_loss:2.0718 train_time:136086ms step_avg:85.05ms -step:1800/9000 train_loss:2.1513 train_time:153111ms step_avg:85.06ms -step:2000/9000 train_loss:2.0619 train_time:170065ms step_avg:85.03ms -step:2200/9000 train_loss:2.1277 train_time:187099ms step_avg:85.04ms -step:2400/9000 train_loss:2.0597 train_time:204035ms step_avg:85.01ms -step:2600/9000 train_loss:2.1060 train_time:221047ms step_avg:85.02ms -step:2800/9000 train_loss:2.1497 train_time:238048ms step_avg:85.02ms -step:3000/9000 train_loss:2.1569 train_time:254965ms step_avg:84.99ms -step:3200/9000 train_loss:2.1658 train_time:272023ms step_avg:85.01ms -step:3400/9000 train_loss:2.0182 train_time:288932ms step_avg:84.98ms -step:3600/9000 train_loss:2.0979 train_time:306055ms step_avg:85.02ms -step:3800/9000 train_loss:2.0741 train_time:323059ms step_avg:85.02ms -step:4000/9000 train_loss:1.9803 train_time:340194ms step_avg:85.05ms -step:4200/9000 train_loss:2.1582 train_time:357294ms step_avg:85.07ms -step:4400/9000 train_loss:2.0428 train_time:374435ms step_avg:85.10ms -step:4600/9000 train_loss:1.8465 train_time:391578ms step_avg:85.13ms -step:4800/9000 train_loss:2.4338 train_time:408766ms step_avg:85.16ms -step:5000/9000 train_loss:2.1121 train_time:425810ms step_avg:85.16ms -step:5200/9000 train_loss:2.0455 train_time:442747ms step_avg:85.14ms -step:5400/9000 train_loss:2.0548 train_time:459747ms step_avg:85.14ms -step:5600/9000 train_loss:1.9599 train_time:476731ms step_avg:85.13ms -step:5800/9000 train_loss:2.0013 train_time:493652ms step_avg:85.11ms -step:6000/9000 train_loss:1.9434 train_time:510668ms step_avg:85.11ms -step:6200/9000 train_loss:1.9595 train_time:527590ms step_avg:85.10ms -step:6400/9000 train_loss:2.0020 train_time:544640ms step_avg:85.10ms -step:6600/9000 train_loss:1.8443 train_time:561567ms step_avg:85.09ms -late_qat:enabled step:6751 scale:0.0998 -step:6800/9000 train_loss:2.0259 train_time:578667ms step_avg:85.10ms -step:7000/9000 train_loss:1.7919 train_time:595662ms step_avg:85.09ms -step:7051/9000 val_loss:1.9279 val_bpb:1.1418 train_time:599986ms step_avg:85.09ms -stopping_early: wallclock_cap train_time:599986ms step:7051/9000 -peak memory allocated: 20596 MiB reserved: 20652 MiB -ema:applying EMA weights -Serialized model: 105783807 bytes -Code size: 67617 bytes -Serialized model int6+zstd: 15544691 bytes -Total submission size int6+zstd: 15612308 bytes -final_int6_roundtrip val_loss:1.9391 val_bpb:1.1485 eval_time:5890ms -final_int6_roundtrip_exact val_loss:1.93912126 val_bpb:1.14845684 -final_int6_sliding_window val_loss:1.8992 val_bpb:1.1248 stride:64 eval_time:72659ms -final_int6_sliding_window_exact val_loss:1.89924867 val_bpb:1.12484502 diff --git a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed42.log b/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed42.log deleted file mode 100644 index c24c9ca961..0000000000 --- a/records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_PartialRoPE_LateQAT_1.1248/train_seed42.log +++ /dev/null @@ -1,1738 +0,0 @@ -""" -train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + -fp16 embed + late-K passthrough + sliding window eval. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 200)) - muon_wd = float(os.environ.get("MUON_WD", 0.02)) - adam_wd = float(os.environ.get("ADAM_WD", 0.01)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) - ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) - ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) - rope_dims = int(os.environ.get("ROPE_DIMS", 0)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) - late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.rope_dims = rope_dims if rope_dims > 0 else dim - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - rd = self.rope_dims - inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - rd = cos.size(-1) * 2 - if rd < x.size(-1): - x_rope, x_pass = x[..., :rd], x[..., rd:] - half = rd // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rot, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = rope_dims - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.use_xsa = False - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) - - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - rope_dims: int = 0, - layer_idx: int = 0, - ln_scale: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - s = self.ln_scale_factor - attn_out = self.attn(self.attn_norm(x) * s) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - rope_dims=rope_dims, - layer_idx=i, - ln_scale=ln_scale, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - - return main_loss - - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - - -# ----------------------------- -# SLIDING WINDOW EVALUATION -# ----------------------------- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -# ----------------------------- -# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) -# ----------------------------- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - row_max = t32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) - return q, scale - amax = t32.abs().max().item() - scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) - return q, scale - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - # tok_emb.weight falls through to int8 via "embed" category - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - CastedLinear._qat_enabled = args.qat_enabled - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - - ema_state: dict[str, Tensor] | None = None - if args.ema_enabled: - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) - if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if ema_state is not None: - d = args.ema_decay - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) - - if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name].add_(t.detach().float()) - swa_count += 1 - - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - if ema_state is not None: - log0("ema:applying EMA weights") - avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) - for name, t in ema_state.items()} - del ema_state - base_model.load_state_dict(avg_state, strict=True) - elif args.swa_enabled and swa_state is not None and swa_count > 1: - log0(f"swa:applying averaged {swa_count} checkpoints") - avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) - for name, t in swa_state.items()} - del swa_state - base_model.load_state_dict(avg_state, strict=True) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - - # Roundtrip: decompress + dequantize into fresh model + eval - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - - # Standard non-overlapping eval (sanity check) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - # Sliding window eval (submission score) - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - - # Second sliding window eval at stride=64 for submission comparison - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Sat Mar 21 05:44:05 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | -| N/A 34C P0 124W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | -| N/A 34C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | -| N/A 31C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | -| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | -| N/A 31C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 30C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 185030 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 185031 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 185032 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 185033 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 185034 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 185035 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 185036 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 185037 C /usr/local/bin/python 1512MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26829913 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:1/9000 train_loss:6.9320 train_time:142ms step_avg:142.11ms -step:2/9000 train_loss:8.7768 train_time:224ms step_avg:112.24ms -step:3/9000 train_loss:7.9054 train_time:315ms step_avg:104.98ms -step:4/9000 train_loss:7.1607 train_time:411ms step_avg:102.83ms -step:5/9000 train_loss:6.9209 train_time:516ms step_avg:103.27ms -step:6/9000 train_loss:6.8715 train_time:614ms step_avg:102.27ms -step:7/9000 train_loss:6.7137 train_time:703ms step_avg:100.43ms -step:8/9000 train_loss:6.5909 train_time:793ms step_avg:99.18ms -step:9/9000 train_loss:6.3294 train_time:889ms step_avg:98.82ms -step:10/9000 train_loss:6.0588 train_time:999ms step_avg:99.95ms -step:200/9000 train_loss:2.4340 train_time:17064ms step_avg:85.32ms -step:400/9000 train_loss:2.4435 train_time:34106ms step_avg:85.26ms -step:600/9000 train_loss:2.3511 train_time:51027ms step_avg:85.04ms -step:800/9000 train_loss:2.2489 train_time:67997ms step_avg:85.00ms -step:1000/9000 train_loss:2.2826 train_time:84896ms step_avg:84.90ms -step:1200/9000 train_loss:2.3525 train_time:101864ms step_avg:84.89ms -step:1400/9000 train_loss:2.1856 train_time:118833ms step_avg:84.88ms -step:1600/9000 train_loss:2.0735 train_time:135730ms step_avg:84.83ms -step:1800/9000 train_loss:2.1486 train_time:152728ms step_avg:84.85ms -step:2000/9000 train_loss:2.0638 train_time:169636ms step_avg:84.82ms -step:2200/9000 train_loss:2.1301 train_time:186633ms step_avg:84.83ms -step:2400/9000 train_loss:2.0638 train_time:203557ms step_avg:84.82ms -step:2600/9000 train_loss:2.1098 train_time:220546ms step_avg:84.83ms -step:2800/9000 train_loss:2.1492 train_time:237532ms step_avg:84.83ms -step:3000/9000 train_loss:2.1564 train_time:254422ms step_avg:84.81ms -step:3200/9000 train_loss:2.1705 train_time:271405ms step_avg:84.81ms -step:3400/9000 train_loss:2.0181 train_time:288316ms step_avg:84.80ms -step:3600/9000 train_loss:2.0976 train_time:305390ms step_avg:84.83ms -step:3800/9000 train_loss:2.0773 train_time:322318ms step_avg:84.82ms -step:4000/9000 train_loss:1.9823 train_time:339425ms step_avg:84.86ms -step:4200/9000 train_loss:2.1635 train_time:356724ms step_avg:84.93ms -step:4400/9000 train_loss:2.0384 train_time:373673ms step_avg:84.93ms -step:4600/9000 train_loss:1.8476 train_time:390667ms step_avg:84.93ms -step:4800/9000 train_loss:2.4311 train_time:407586ms step_avg:84.91ms -step:5000/9000 train_loss:2.1140 train_time:424611ms step_avg:84.92ms -step:5200/9000 train_loss:2.0471 train_time:441585ms step_avg:84.92ms -step:5400/9000 train_loss:2.0539 train_time:458641ms step_avg:84.93ms -step:5600/9000 train_loss:1.9587 train_time:475668ms step_avg:84.94ms -step:5800/9000 train_loss:2.0055 train_time:492573ms step_avg:84.93ms -step:6000/9000 train_loss:1.9430 train_time:509591ms step_avg:84.93ms -step:6200/9000 train_loss:1.9568 train_time:526503ms step_avg:84.92ms -step:6400/9000 train_loss:2.0014 train_time:543572ms step_avg:84.93ms -step:6600/9000 train_loss:1.8450 train_time:560515ms step_avg:84.93ms -late_qat:enabled step:6764 scale:0.1000 -step:6800/9000 train_loss:2.0266 train_time:577545ms step_avg:84.93ms -step:7000/9000 train_loss:1.7903 train_time:594889ms step_avg:84.98ms -step:7061/9000 val_loss:1.9278 val_bpb:1.1418 train_time:600046ms step_avg:84.98ms -stopping_early: wallclock_cap train_time:600046ms step:7061/9000 -peak memory allocated: 20596 MiB reserved: 20652 MiB -ema:applying EMA weights -Serialized model: 105783807 bytes -Code size: 67617 bytes -Serialized model int6+zstd: 15461049 bytes -Total submission size int6+zstd: 15528666 bytes -final_int6_roundtrip val_loss:1.9394 val_bpb:1.1486 eval_time:6093ms -final_int6_roundtrip_exact val_loss:1.93940332 val_bpb:1.14862389 -final_int6_sliding_window val_loss:1.8995 val_bpb:1.1250 stride:64 eval_time:72687ms -final_int6_sliding_window_exact val_loss:1.89948149 val_bpb:1.12498291 diff --git a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/README.md b/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/README.md deleted file mode 100644 index cc01a339bc..0000000000 --- a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/README.md +++ /dev/null @@ -1,76 +0,0 @@ -## Record: 11L EMA + GPTQ-lite + warmdown3500 + QAT@0.15 (val_bpb: 1.1233) - -**val_bpb: 1.1233** (sliding window stride=64, 3-seed mean) | **15.55 MB** (mean) | 8xH100 SXM, 600s - -### Key Innovations Over PR #374 - -Two novel post-training optimizations plus training hyperparameter tuning on top of PR #374's architecture: - -| Change | PR #374 | This | Impact | -|--------|---------|------|--------| -| **GPTQ-lite** | Fixed clip (row max) | 5 clip percentiles per row, pick min MSE | -0.0006 BPB (zero training cost) | -| **EMA** | None (Tight SWA only) | EMA decay=0.997 every step | -0.0006 BPB (smoother averaging) | -| **Warmdown** | 3000 | 3500 | -0.0002 BPB | -| **Late QAT threshold** | 0.1 | 0.15 | -0.0001 BPB (earlier fake quant, smaller quant gap) | -| **Total** | **1.1246** | **1.1233** | **-0.0013 BPB** | - -### GPTQ-lite: Per-Layer Optimal Clip Percentile Search - -Instead of using the row maximum for int6 quantization scale, we try 5 clip percentiles (0.999, 0.9995, 0.9999, 0.99999, 1.0) per weight matrix row and pick the one minimizing reconstruction MSE. This is applied during post-training quantization with zero training cost. - -### EMA Weight Averaging - -Exponential moving average (decay=0.997) maintained every training step, applied before quantization. Stacks with Tight SWA — EMA provides continuous smoothing while SWA captures discrete checkpoints during warmdown. - -### Results (3 seeds, 8xH100 SXM) - -| Seed | Steps | val_loss | Sliding BPB (s64) | Artifact | -|------|-------|----------|-------------------|----------| -| **1337** | 7101 | 1.8958 | **1.1228** | 15.56 MB | -| 42 | ~7100 | 1.8972 | 1.1236 | 15.54 MB | -| 2024 | ~7100 | 1.8971 | 1.1236 | 15.59 MB | - -**Mean: 1.1233 | Std: 0.0005** | Submitted: seed 1337 (best) - -### Architecture (from PR #374) - -- 11 transformer layers, 512-dim, 8 heads (4 KV heads, GQA) -- 3x MLP expansion (1536 hidden), relu-squared activation -- U-Net skip connections (5 encoder, 6 decoder) -- Efficient Partial XSA on last 4 layers (GQA-aware, zero-alloc) -- Partial RoPE (16/64 dims) + NTK-aware scaling -- LN Scale Factor 1/sqrt(layer_idx+1) -- Shared Value Embedding (dim=128, layers 9,10) with per-layer learned scales -- SmearGate + BigramHash (2048 buckets, dim=128) -- Tied embeddings, logit softcap=30.0 - -### Training - -- FlashAttention 3 (Hopper-optimized) -- Muon optimizer (matrices): lr=0.025, momentum=0.99 (warmup 0.92->0.99 over 1500 steps), WD=0.04 -- AdamW (embeddings): lr=0.035, (scalars): lr=0.025, WD=0.04 -- Gradient clip: 0.3 -- Batch: 786,432 tokens/step, seq_len=2048 -- Warmdown: 3500 iterations (wallclock-based) -- **EMA**: decay=0.997, every step -- **Tight SWA**: every 50 steps when scale<0.2 -- **Late QAT**: STE int6 fake-quantization when LR scale<0.15 -- OrthoInit + muP-scaled output projections - -### Quantization - -- **GPTQ-lite**: Per-row optimal clip percentile search (5 candidates) for int6 -- Int6 per-row for MLP + attention weights -- Int8 per-row for embeddings -- Control tensors in fp32 -- zstd level 22 compression - -### Run Command - -```bash -SEED=1337 bash eval/eval.sh -``` - -### Reproducibility - -All 3 seeds produce valid artifacts under 16MB with tight variance (std=0.0005 BPB). The GPTQ-lite clip search is deterministic. diff --git a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/submission.json b/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/submission.json deleted file mode 100644 index e7cc1eb49c..0000000000 --- a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/submission.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "author": "Tianhao Wu", - "github_id": "signalrush", - "name": "Record: 11L EMA + GPTQ-lite + warmdown3500 + QAT@0.15", - "blurb": "EMA(0.997) weight averaging + GPTQ-lite optimal clip percentile search + warmdown=3500 + Late QAT threshold=0.15, built on PR#374 stack (11L, XSA4, Partial RoPE 16/64, LN Scale, VE128, Tight SWA, SmearGate, BigramHash, int6+zstd-22).", - "date": "2026-03-22T00:00:00Z", - "val_loss": 1.89576235, - "val_bpb": 1.12278022, - "bytes_total": 15555017 -} diff --git a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train.log b/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train.log deleted file mode 100644 index 7074fcb3f8..0000000000 --- a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train.log +++ /dev/null @@ -1,86 +0,0 @@ -Running: torchrun --standalone --nproc_per_node=8 train_gpt.py -W0321 23:58:21.049000 2107250 site-packages/torch/distributed/run.py:852] -W0321 23:58:21.049000 2107250 site-packages/torch/distributed/run.py:852] ***************************************** -W0321 23:58:21.049000 2107250 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0321 23:58:21.049000 2107250 site-packages/torch/distributed/run.py:852] ***************************************** -logs/4d9a81fd-61ff-4d91-ba5b-8122bf0d29b1.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26993756 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9279 val_bpb:4.1031 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9299 train_time:132ms step_avg:132.47ms -step:2/20000 train_loss:8.5642 train_time:210ms step_avg:105.01ms -step:3/20000 train_loss:7.8180 train_time:292ms step_avg:97.49ms -step:4/20000 train_loss:7.2359 train_time:374ms step_avg:93.61ms -step:5/20000 train_loss:7.0667 train_time:457ms step_avg:91.34ms -step:6/20000 train_loss:6.8298 train_time:539ms step_avg:89.79ms -step:7/20000 train_loss:6.7205 train_time:621ms step_avg:88.70ms -step:8/20000 train_loss:6.7474 train_time:703ms step_avg:87.89ms -step:9/20000 train_loss:6.4050 train_time:785ms step_avg:87.22ms -step:10/20000 train_loss:6.0813 train_time:867ms step_avg:86.73ms -step:500/20000 train_loss:2.3955 train_time:42105ms step_avg:84.21ms -step:1000/20000 train_loss:2.2705 train_time:84305ms step_avg:84.30ms -step:1500/20000 train_loss:2.2150 train_time:126530ms step_avg:84.35ms -step:2000/20000 train_loss:2.0567 train_time:168765ms step_avg:84.38ms -step:2500/20000 train_loss:2.1560 train_time:211011ms step_avg:84.40ms -step:3000/20000 train_loss:2.1529 train_time:253267ms step_avg:84.42ms -step:3500/20000 train_loss:2.1761 train_time:295680ms step_avg:84.48ms -step:4000/20000 train_loss:1.9673 train_time:337942ms step_avg:84.49ms -step:4000/20000 val_loss:2.0601 val_bpb:1.2201 train_time:337946ms step_avg:84.49ms -step:4500/20000 train_loss:2.1182 train_time:380212ms step_avg:84.49ms -step:5000/20000 train_loss:2.1014 train_time:422463ms step_avg:84.49ms -step:5500/20000 train_loss:2.0151 train_time:464730ms step_avg:84.50ms -step:6000/20000 train_loss:1.9386 train_time:507080ms step_avg:84.51ms -swa:start step:6450 -step:6500/20000 train_loss:2.0793 train_time:549410ms step_avg:84.52ms -late_qat:enabled step:6574 scale:0.1499 -step:7000/20000 train_loss:1.7878 train_time:591918ms step_avg:84.56ms -step:7096/20000 val_loss:1.9240 val_bpb:1.1395 train_time:600036ms step_avg:84.56ms -stopping_early: wallclock_cap train_time:600036ms step:7096/20000 -peak memory allocated: 20862 MiB reserved: 21204 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9222 val_bpb:1.1385 eval_time:1960ms -Serialized model: 106178569 bytes -Code size: 67603 bytes -Serialized model int6+zstd: 15487414 bytes -Total submission size int6+zstd: 15555017 bytes -Total submission size int8+zlib: 15555017 bytes -final_int6_roundtrip val_loss:1.9359 val_bpb:1.1466 eval_time:5788ms -final_int6_roundtrip_exact val_loss:1.93592275 val_bpb:1.14656250 -final_int6_sliding_window val_loss:1.8958 val_bpb:1.1228 stride:64 eval_time:73133ms -final_int6_sliding_window_exact val_loss:1.89576235 val_bpb:1.12278022 -final_int8_zlib_roundtrip_exact val_loss:1.89576235 val_bpb:1.12278022 ---- -val_bpb: 1.12278022 -artifact_bytes: 15555017 -line_count: 1402 -valid: true diff --git a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_gpt.py b/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_gpt.py deleted file mode 100644 index eb42a63276..0000000000 --- a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_gpt.py +++ /dev/null @@ -1,1402 +0,0 @@ -from __future__ import annotations -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func as flash_attn_3_func -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints - muon_wd = float(os.environ.get("MUON_WD", 0.04)) - adam_wd = float(os.environ.get("ADAM_WD", 0.04)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) - rope_dims = int(os.environ.get("ROPE_DIMS", 16)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) - dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) - late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) - ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) - ve_dim = int(os.environ.get("VE_DIM", 128)) - ve_layers = os.environ.get("VE_LAYERS", "9,10") -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - self.rope_dims = rope_dims if rope_dims > 0 else dim - inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: - if rope_dims > 0 and rope_dims < x.size(-1): - x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] - half = rope_dims // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rope, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = 0 # set by GPT.__init__ for partial RoPE - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x) - if v_embed is not None: - v = v + v_embed - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin, self.rope_dims) - k = apply_rotary_emb(k, cos, sin, self.rope_dims) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - y = y.reshape(bsz, seqlen, dim) - return self.proj(y) -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) -class ValueEmbedding(nn.Module): - """Reinject token identity into attention values at specific layers. - Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" - def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): - super().__init__() - self.embed = nn.Embedding(vocab_size, ve_dim) - nn.init.normal_(self.embed.weight, std=0.01) - self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(token_ids) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = int(mlp_mult * dim) - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - layer_idx: int = 0, - ln_scale: bool = False, - dtg: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - if dtg: - self.dtg_gate = nn.Linear(dim, 1, bias=True) - nn.init.zeros_(self.dtg_gate.weight) - nn.init.constant_(self.dtg_gate.bias, 2.0) - else: - self.dtg_gate = None - def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) - x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out - x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) - if self.dtg_gate is not None: - gate = torch.sigmoid(self.dtg_gate(x_in.detach())) - x_out = x_in + gate * (x_out - x_in) - return x_out -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - dtg: bool = False, - ve_enabled: bool = False, - ve_dim: int = 128, - ve_layers: str = "9,10", - ): - super().__init__() - self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - layer_idx=i, - ln_scale=ln_scale, - dtg=dtg, - ) - for i in range(num_layers) - ] - ) - if rope_dims > 0: - head_dim = model_dim // num_heads - for block in self.blocks: - block.attn.rope_dims = rope_dims - block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] - kv_dim = self._ve_target_dim - if self.ve_layer_indices: - self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) - self.ve_layer_scales = nn.ParameterList( - [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] - ) - else: - self.ve_shared = None - self.ve_layer_scales = nn.ParameterList() - self.value_embeds = nn.ModuleList() # keep empty for compat - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - num_layers = len(self.blocks) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - if ".proj." in name or name.endswith(".proj"): - with torch.no_grad(): - module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) - def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: - """Get value embedding for a specific layer using shared table + per-layer scale.""" - if self.ve_shared is None or layer_idx not in self.ve_layer_indices: - return None - if ve_cache is not None and 've' not in ve_cache: - ve_cache['ve'] = self.ve_shared(input_ids) - ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) - ve_idx = self.ve_layer_indices.index(layer_idx) - return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x = self.blocks[i](x, x0, v_embed=ve) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x = self.blocks[bi](x, x0, v_embed=ve) - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - return main_loss - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x = self.blocks[i](x, x0, v_embed=ve) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x = self.blocks[bi](x, x0, v_embed=ve) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" -def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out -def main() -> None: - global zeropower_via_newtonschulz5 - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - CastedLinear._qat_enabled = args.qat_enabled - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, - ve_dim=args.ve_dim, - ve_layers=args.ve_layers, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.mtp_num_heads > 0: - matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - matrix_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - matrix_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = 0.997 - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - # EMA update - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - # Apply EMA weights (better than SWA alone per PR#401) - log0("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - torch.cuda.synchronize() - t_diag = time.perf_counter() - diag_val_loss, diag_val_bpb = eval_val( - args, compiled_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" - ) - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") - log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, # must match training model - rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - ).to(device).bfloat16() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed1337.log b/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed1337.log deleted file mode 100644 index 7074fcb3f8..0000000000 --- a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed1337.log +++ /dev/null @@ -1,86 +0,0 @@ -Running: torchrun --standalone --nproc_per_node=8 train_gpt.py -W0321 23:58:21.049000 2107250 site-packages/torch/distributed/run.py:852] -W0321 23:58:21.049000 2107250 site-packages/torch/distributed/run.py:852] ***************************************** -W0321 23:58:21.049000 2107250 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0321 23:58:21.049000 2107250 site-packages/torch/distributed/run.py:852] ***************************************** -logs/4d9a81fd-61ff-4d91-ba5b-8122bf0d29b1.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26993756 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9279 val_bpb:4.1031 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9299 train_time:132ms step_avg:132.47ms -step:2/20000 train_loss:8.5642 train_time:210ms step_avg:105.01ms -step:3/20000 train_loss:7.8180 train_time:292ms step_avg:97.49ms -step:4/20000 train_loss:7.2359 train_time:374ms step_avg:93.61ms -step:5/20000 train_loss:7.0667 train_time:457ms step_avg:91.34ms -step:6/20000 train_loss:6.8298 train_time:539ms step_avg:89.79ms -step:7/20000 train_loss:6.7205 train_time:621ms step_avg:88.70ms -step:8/20000 train_loss:6.7474 train_time:703ms step_avg:87.89ms -step:9/20000 train_loss:6.4050 train_time:785ms step_avg:87.22ms -step:10/20000 train_loss:6.0813 train_time:867ms step_avg:86.73ms -step:500/20000 train_loss:2.3955 train_time:42105ms step_avg:84.21ms -step:1000/20000 train_loss:2.2705 train_time:84305ms step_avg:84.30ms -step:1500/20000 train_loss:2.2150 train_time:126530ms step_avg:84.35ms -step:2000/20000 train_loss:2.0567 train_time:168765ms step_avg:84.38ms -step:2500/20000 train_loss:2.1560 train_time:211011ms step_avg:84.40ms -step:3000/20000 train_loss:2.1529 train_time:253267ms step_avg:84.42ms -step:3500/20000 train_loss:2.1761 train_time:295680ms step_avg:84.48ms -step:4000/20000 train_loss:1.9673 train_time:337942ms step_avg:84.49ms -step:4000/20000 val_loss:2.0601 val_bpb:1.2201 train_time:337946ms step_avg:84.49ms -step:4500/20000 train_loss:2.1182 train_time:380212ms step_avg:84.49ms -step:5000/20000 train_loss:2.1014 train_time:422463ms step_avg:84.49ms -step:5500/20000 train_loss:2.0151 train_time:464730ms step_avg:84.50ms -step:6000/20000 train_loss:1.9386 train_time:507080ms step_avg:84.51ms -swa:start step:6450 -step:6500/20000 train_loss:2.0793 train_time:549410ms step_avg:84.52ms -late_qat:enabled step:6574 scale:0.1499 -step:7000/20000 train_loss:1.7878 train_time:591918ms step_avg:84.56ms -step:7096/20000 val_loss:1.9240 val_bpb:1.1395 train_time:600036ms step_avg:84.56ms -stopping_early: wallclock_cap train_time:600036ms step:7096/20000 -peak memory allocated: 20862 MiB reserved: 21204 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9222 val_bpb:1.1385 eval_time:1960ms -Serialized model: 106178569 bytes -Code size: 67603 bytes -Serialized model int6+zstd: 15487414 bytes -Total submission size int6+zstd: 15555017 bytes -Total submission size int8+zlib: 15555017 bytes -final_int6_roundtrip val_loss:1.9359 val_bpb:1.1466 eval_time:5788ms -final_int6_roundtrip_exact val_loss:1.93592275 val_bpb:1.14656250 -final_int6_sliding_window val_loss:1.8958 val_bpb:1.1228 stride:64 eval_time:73133ms -final_int6_sliding_window_exact val_loss:1.89576235 val_bpb:1.12278022 -final_int8_zlib_roundtrip_exact val_loss:1.89576235 val_bpb:1.12278022 ---- -val_bpb: 1.12278022 -artifact_bytes: 15555017 -line_count: 1402 -valid: true diff --git a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed2024.log b/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed2024.log deleted file mode 100644 index 9412ffde44..0000000000 --- a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed2024.log +++ /dev/null @@ -1,86 +0,0 @@ -Running: torchrun --standalone --nproc_per_node=8 train_gpt.py -W0322 00:23:51.386000 2115243 site-packages/torch/distributed/run.py:852] -W0322 00:23:51.386000 2115243 site-packages/torch/distributed/run.py:852] ***************************************** -W0322 00:23:51.386000 2115243 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0322 00:23:51.386000 2115243 site-packages/torch/distributed/run.py:852] ***************************************** -logs/5f9d529b-86e0-49b1-bc78-9db5b6f035ec.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26993756 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:2024 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9317 val_bpb:4.1053 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9330 train_time:131ms step_avg:131.07ms -step:2/20000 train_loss:8.7269 train_time:211ms step_avg:105.41ms -step:3/20000 train_loss:7.8782 train_time:293ms step_avg:97.67ms -step:4/20000 train_loss:7.1893 train_time:375ms step_avg:93.82ms -step:5/20000 train_loss:6.9777 train_time:457ms step_avg:91.48ms -step:6/20000 train_loss:6.8793 train_time:540ms step_avg:89.96ms -step:7/20000 train_loss:6.7214 train_time:622ms step_avg:88.82ms -step:8/20000 train_loss:6.6369 train_time:704ms step_avg:88.00ms -step:9/20000 train_loss:6.4078 train_time:786ms step_avg:87.36ms -step:10/20000 train_loss:6.0701 train_time:868ms step_avg:86.82ms -step:500/20000 train_loss:2.4065 train_time:42148ms step_avg:84.30ms -step:1000/20000 train_loss:2.2735 train_time:84368ms step_avg:84.37ms -step:1500/20000 train_loss:2.2205 train_time:126591ms step_avg:84.39ms -step:2000/20000 train_loss:2.0551 train_time:168827ms step_avg:84.41ms -step:2500/20000 train_loss:2.1597 train_time:211064ms step_avg:84.43ms -step:3000/20000 train_loss:2.1585 train_time:253312ms step_avg:84.44ms -step:3500/20000 train_loss:2.1734 train_time:295628ms step_avg:84.47ms -step:4000/20000 train_loss:1.9680 train_time:337863ms step_avg:84.47ms -step:4000/20000 val_loss:2.0602 val_bpb:1.2202 train_time:337868ms step_avg:84.47ms -step:4500/20000 train_loss:2.1197 train_time:380107ms step_avg:84.47ms -step:5000/20000 train_loss:2.1022 train_time:422352ms step_avg:84.47ms -step:5500/20000 train_loss:2.0151 train_time:464612ms step_avg:84.47ms -step:6000/20000 train_loss:1.9397 train_time:506844ms step_avg:84.47ms -swa:start step:6450 -step:6500/20000 train_loss:2.0784 train_time:549168ms step_avg:84.49ms -late_qat:enabled step:6577 scale:0.1498 -step:7000/20000 train_loss:1.7901 train_time:591726ms step_avg:84.53ms -step:7099/20000 val_loss:1.9248 val_bpb:1.1400 train_time:600070ms step_avg:84.53ms -stopping_early: wallclock_cap train_time:600070ms step:7099/20000 -peak memory allocated: 20862 MiB reserved: 21204 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9230 val_bpb:1.1389 eval_time:1962ms -Serialized model: 106178569 bytes -Code size: 67603 bytes -Serialized model int6+zstd: 15519202 bytes -Total submission size int6+zstd: 15586805 bytes -Total submission size int8+zlib: 15586805 bytes -final_int6_roundtrip val_loss:1.9370 val_bpb:1.1472 eval_time:5773ms -final_int6_roundtrip_exact val_loss:1.93702996 val_bpb:1.14721825 -final_int6_sliding_window val_loss:1.8971 val_bpb:1.1236 stride:64 eval_time:73179ms -final_int6_sliding_window_exact val_loss:1.89708903 val_bpb:1.12356596 -final_int8_zlib_roundtrip_exact val_loss:1.89708903 val_bpb:1.12356596 ---- -val_bpb: 1.12356596 -artifact_bytes: 15586805 -line_count: 1402 -valid: true diff --git a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed42.log b/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed42.log deleted file mode 100644 index ad9d29b76a..0000000000 --- a/records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/train_seed42.log +++ /dev/null @@ -1,86 +0,0 @@ -Running: torchrun --standalone --nproc_per_node=8 train_gpt.py -W0322 00:11:06.893000 2111397 site-packages/torch/distributed/run.py:852] -W0322 00:11:06.893000 2111397 site-packages/torch/distributed/run.py:852] ***************************************** -W0322 00:11:06.893000 2111397 site-packages/torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0322 00:11:06.893000 2111397 site-packages/torch/distributed/run.py:852] ***************************************** -logs/4b9790dd-118f-4acd-ab65-d51f91a21e75.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26993756 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9318 train_time:133ms step_avg:133.06ms -step:2/20000 train_loss:8.6431 train_time:210ms step_avg:105.23ms -step:3/20000 train_loss:7.8534 train_time:293ms step_avg:97.68ms -step:4/20000 train_loss:7.2570 train_time:375ms step_avg:93.87ms -step:5/20000 train_loss:7.0164 train_time:457ms step_avg:91.49ms -step:6/20000 train_loss:6.8954 train_time:540ms step_avg:89.94ms -step:7/20000 train_loss:6.7667 train_time:622ms step_avg:88.80ms -step:8/20000 train_loss:6.6946 train_time:704ms step_avg:87.96ms -step:9/20000 train_loss:6.4074 train_time:786ms step_avg:87.31ms -step:10/20000 train_loss:6.0739 train_time:868ms step_avg:86.76ms -step:500/20000 train_loss:2.4001 train_time:42133ms step_avg:84.27ms -step:1000/20000 train_loss:2.2734 train_time:84360ms step_avg:84.36ms -step:1500/20000 train_loss:2.2181 train_time:126581ms step_avg:84.39ms -step:2000/20000 train_loss:2.0579 train_time:168831ms step_avg:84.42ms -step:2500/20000 train_loss:2.1605 train_time:211061ms step_avg:84.42ms -step:3000/20000 train_loss:2.1540 train_time:253285ms step_avg:84.43ms -step:3500/20000 train_loss:2.1728 train_time:295506ms step_avg:84.43ms -step:4000/20000 train_loss:1.9715 train_time:337712ms step_avg:84.43ms -step:4000/20000 val_loss:2.0617 val_bpb:1.2211 train_time:337716ms step_avg:84.43ms -step:4500/20000 train_loss:2.1197 train_time:379916ms step_avg:84.43ms -step:5000/20000 train_loss:2.1018 train_time:422121ms step_avg:84.42ms -step:5500/20000 train_loss:2.0142 train_time:464335ms step_avg:84.42ms -step:6000/20000 train_loss:1.9376 train_time:506553ms step_avg:84.43ms -swa:start step:6450 -step:6500/20000 train_loss:2.0800 train_time:548836ms step_avg:84.44ms -late_qat:enabled step:6581 scale:0.1500 -step:7000/20000 train_loss:1.7895 train_time:591335ms step_avg:84.48ms -step:7103/20000 val_loss:1.9253 val_bpb:1.1403 train_time:600059ms step_avg:84.48ms -stopping_early: wallclock_cap train_time:600059ms step:7103/20000 -peak memory allocated: 20862 MiB reserved: 21204 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9235 val_bpb:1.1392 eval_time:1959ms -Serialized model: 106178569 bytes -Code size: 67603 bytes -Serialized model int6+zstd: 15473072 bytes -Total submission size int6+zstd: 15540675 bytes -Total submission size int8+zlib: 15540675 bytes -final_int6_roundtrip val_loss:1.9371 val_bpb:1.1473 eval_time:5814ms -final_int6_roundtrip_exact val_loss:1.93712085 val_bpb:1.14727208 -final_int6_sliding_window val_loss:1.8972 val_bpb:1.1236 stride:64 eval_time:73133ms -final_int6_sliding_window_exact val_loss:1.89715168 val_bpb:1.12360306 -final_int8_zlib_roundtrip_exact val_loss:1.89715168 val_bpb:1.12360306 ---- -val_bpb: 1.12360306 -artifact_bytes: 15540675 -line_count: 1402 -valid: true diff --git a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/README.md b/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/README.md deleted file mode 100644 index f32c20f847..0000000000 --- a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/README.md +++ /dev/null @@ -1,125 +0,0 @@ -# LeakyReLU² + Legal Score-First TTT + Parallel Muon - -**val_bpb: 1.1194** (3-seed mean, std 0.0006) | **~15.95 MB** | 8×H100 SXM - -## Results (8×H100 80GB SXM, PyTorch 2.9.1+cu128) - -| Seed | step_avg | steps | Pre-TTT bpb | **Post-TTT bpb** | TTT gain | TTT time | Artifact | -|------|----------|-------|-------------|-----------------|----------|----------|----------| -| 1337 | 83.3ms | 7,179 | 1.1217 | **1.1192** | -0.0025 | 410s | 15,977,386 | -| 42 | 83.4ms | 7,182 | 1.1227 | **1.1200** | -0.0027 | 408s | 15,876,510 | -| 2025 | 83.4ms | 7,193 | 1.1212 | **1.1189** | -0.0023 | 408s | 15,990,006 | -| **Mean** | **83.4ms** | **7,185** | **1.1218** | **1.1194 (std 0.0006)** | **-0.0025** | **~409s** | | - -## Key Innovation: LeakyReLU(0.5)² - -One-line activation change that delivers -0.003 BPB: - -```python -# Standard (relu²) -x = torch.relu(self.fc(x)).square() - -# This submission (leaky relu²) -x = F.leaky_relu(self.fc(x), negative_slope=0.5).square() -``` - -LeakyReLU with slope 0.5 preserves negative gradient flow through the MLP, allowing the model to learn from both positive and negative pre-activations. The squaring step still produces non-negative outputs, maintaining the relu² inductive bias while eliminating dead neurons. - -This activation is used in PR #493 (ablated at -0.003 BPB) and PR #518 (part of their 1.0622 record submission). - -## Legal TTT Protocol - -Backward-looking, score-first TTT following PR #461's framework: - -1. Val tokens split into 1,893 non-overlapping 32K-token chunks -2. **For each chunk**: - - **SCORE**: Sliding window eval under `torch.inference_mode()` — no gradients, no weight mutation possible - - **TRAIN**: SGD(lr=0.002, momentum=0.9) on the already-scored chunk. 3 epochs, all blocks unfrozen, cosine LR decay, grad clip 1.0 -3. Last chunk scored but never trained on -4. Chunk N scored by model adapted only on chunks 0..N-1 - -`inference_mode()` is a PyTorch context manager that disables gradient tracking and prohibits in-place weight mutation, providing a hard guarantee that scoring is stateless. - -### TTT Hyperparameters - -| Parameter | Value | -|-----------|-------| -| Chunk size | 32,768 tokens | -| Optimizer | SGD + momentum(0.9) | -| Learning rate | 0.002 (cosine decay across chunks) | -| Epochs per chunk | 3 | -| Frozen blocks | None (all blocks adapt) | -| Gradient clip | 1.0 | - -### Timing Budget - -| Phase | Time | -|-------|------| -| Training | 600s (≤10 min) | -| Standard eval (int6 roundtrip + sliding window) | ~120s | -| Legal TTT (score-first sliding + adaptation) | ~410s | -| **Total eval** | **~530s (< 10 min)** | - -## Training Architecture - -PR #414 stack with Parameter Banking + Parallel Muon (PR #399): - -| Component | Setting | -|-----------|---------| -| Layers | 11 (512d, 8H, 4KV) | -| MLP | 3× with **LeakyReLU(0.5)²** | -| BigramHash | 1536 | -| XSA | Last 4 layers | -| RoPE | Partial (16/64 dims) | -| LN Scale | 1/√(layer+1) | -| VE128 | Layers 9-10 | -| Weight avg | EMA(0.997) + Tight SWA(every 50) | -| Quantization | GPTQ-lite int6 + lzma | -| Optimizer | Parameter Banking + Parallel Muon | - -### Parameter Banking + Parallel Muon - -First introduced in [PR #399](https://github.com/openai/parameter-golf/pull/399): - -- 4 contiguous 3D `nn.Parameter` banks replace 66 separate `nn.Linear` weights -- Batched Newton-Schulz orthogonalization via `torch.bmm` -- DDP removed for banks; async reduce-scatter → local NS → async all-gather -- 83.3ms/step vs ~85ms baseline - -## Run Command - -```bash -NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=1536 XSA_LAST_N=4 \ -EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \ -ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 \ -VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \ -TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 \ -TTT_FREEZE_BLOCKS=0 TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \ -MUON_WD=0.04 ADAM_WD=0.04 \ -MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ -MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \ -MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \ -ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \ -SEED=1337 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -## Ablation - -Incremental contribution of each technique (all seed 1337): - -| Change | Pre-TTT bpb | Post-TTT bpb | Delta | -|--------|-------------|-------------|-------| -| PR #414 base (relu², BIGRAM=2048) | 1.1234 | — | — | -| + Parameter Banking + Parallel Muon | 1.1234 | — | ±0.0000 | -| + Legal TTT (3ep, freeze=2) | — | 1.1217 | -0.0017 | -| + TTT freeze=0 (all blocks) | — | 1.1213 | -0.0004 | -| + BigramHash 2048→3072 | — | 1.1204 | -0.0009 | -| + **LeakyReLU(0.5)²** | 1.1213 | **1.1183** | **-0.0021** | - -## Credits - -- **LeakyReLU² activation**: [PR #493](https://github.com/openai/parameter-golf/pull/493) by @parinzee, [PR #518](https://github.com/openai/parameter-golf/pull/518) by @sofiabod -- **Optimizer (Parameter Banking + Parallel Muon)**: [PR #399](https://github.com/openai/parameter-golf/pull/399) by @abaybektursun -- **TTT recipe**: [PR #461](https://github.com/openai/parameter-golf/pull/461) by @Christopher-Lee-McClendon (adapted: freeze=0 instead of original freeze=2) -- **Base model**: [PR #414](https://github.com/openai/parameter-golf/pull/414) by @signalrush diff --git a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/submission.json b/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/submission.json deleted file mode 100644 index eb947f43af..0000000000 --- a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/submission.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "name": "LeakyReLU² + Legal Score-First TTT + Parallel Muon", - "val_bpb": 1.1194, - "bytes_total": 15990006, - "blurb": "LeakyReLU(0.5)² activation (-0.003 BPB vs relu²) + legal score-first TTT (PR #461 recipe, 3ep SGD, all blocks unfrozen) + BigramHash(1536) + Parameter Banking + Parallel Muon (PR #399). Built on PR #414 stack. 3-seed mean: 1.1194 (std 0.0006). All artifacts under 16MB, all eval under 10 min.", - "author": "abaybektursun", - "github_id": "abaybektursun", - "date": "2026-03-23" -} diff --git a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_gpt.py b/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_gpt.py deleted file mode 100644 index e2d89198d0..0000000000 --- a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_gpt.py +++ /dev/null @@ -1,1898 +0,0 @@ -from __future__ import annotations -import copy -import glob -import io -import lzma -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func as flash_attn_3_func -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) - lawa_k = int(os.environ.get("LAWA_K", 10)) - lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) - muon_wd = float(os.environ.get("MUON_WD", 0.04)) - adam_wd = float(os.environ.get("ADAM_WD", 0.04)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) - rope_dims = int(os.environ.get("ROPE_DIMS", 16)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) - dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) - late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) - ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) - ve_dim = int(os.environ.get("VE_DIM", 128)) - ve_layers = os.environ.get("VE_LAYERS", "9,10") - gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) - value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) - ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) - ttt_lr = float(os.environ.get("TTT_LR", 0.002)) - ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) - ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) - ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) - ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) - ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) - ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) - -# --- Batched Newton-Schulz orthogonalization --- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: - """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" - a, b, c = (3.4445, -4.7750, 2.0315) - was_2d = G.ndim == 2 - if was_2d: - G = G.unsqueeze(0) - X = G.bfloat16() - transposed = X.size(-2) > X.size(-1) - if transposed: - X = X.mT - X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) - for _ in range(steps): - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - if transposed: - X = X.mT - if was_2d: - X = X.squeeze(0) - return X - -# --- Parallel Muon optimizer --- - -class Muon(torch.optim.Optimizer): - """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. - - No DDP for bank params. After backward, this optimizer: - 1. Launches async reduce-scatter for all banks (biggest first) - 2. Returns control so Adam can step on small params while RS is in-flight - 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather - 4. Each all-gather overlaps with next bank's NS5 - """ - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - self._built = False - - def _build(self): - self._distributed = dist.is_available() and dist.is_initialized() - self._world_size = dist.get_world_size() if self._distributed else 1 - self._rank = dist.get_rank() if self._distributed else 0 - ws = self._world_size - - self._bank_meta = [] - for group in self.param_groups: - for p in group["params"]: - B = p.shape[0] - padded_B = ((B + ws - 1) // ws) * ws - shard_B = padded_B // ws - tail = p.shape[1:] - dev = p.device - self._bank_meta.append({ - 'p': p, - 'B': B, - 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, - }) - # Sort by size descending -- launch biggest reduce-scatters first - self._bank_meta.sort(key=lambda m: -m['p'].numel()) - self._built = True - - def launch_reduce_scatters(self): - """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" - if not self._built: - self._build() - if not self._distributed: - return - self._rs_futures = [] - for m in self._bank_meta: - p = m['p'] - if p.grad is None: - self._rs_futures.append(None) - continue - pg = m['padded_grad'] - pg[:m['B']].copy_(p.grad.bfloat16()) - if pg.shape[0] > m['B']: - pg[m['B']:].zero_() - fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) - self._rs_futures.append(fut) - - @torch.no_grad() - def step(self, closure=None): - """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - if not self._built: - self._build() - - for group in self.param_groups: - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group.get("weight_decay", 0.0) - - prev_ag_handle = None - prev_m = None - - sharded = self._distributed and hasattr(self, '_rs_futures') - - for i, m in enumerate(self._bank_meta): - p = m['p'] - if p.grad is None: - continue - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if sharded and self._rs_futures[i] is not None: - self._rs_futures[i].wait() - g = m['shard'] - buf = m['shard_mom'] - else: - g = p.grad.bfloat16() - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - - buf.mul_(momentum).add_(g) - if nesterov: - update = g.add(buf, alpha=momentum) - else: - update = buf - - update = zeropower_via_newtonschulz5(update, steps=backend_steps) - - if sharded: - prev_ag_handle = dist.all_gather_into_tensor( - m['full_update'], update, async_op=True) - prev_m = m - else: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if hasattr(self, '_rs_futures'): - del self._rs_futures - - return loss - -# --- Tokenizer evaluation helpers --- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# --- Quantization helpers --- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - -# --- Data loading --- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# --- Transformer modules --- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - self.rope_dims = rope_dims if rope_dims > 0 else dim - inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: - if rope_dims > 0 and rope_dims < x.size(-1): - x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] - half = rope_dims // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rope, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - # No CastedLinear -- weights come from banks - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = 0 # set by GPT.__init__ for partial RoPE - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - # Gated attention and value residual (non-banked small params) - self.gated_attention = gated_attention - if gated_attention: - self.attn_gate = nn.Linear(dim, num_heads, bias=True) - nn.init.zeros_(self.attn_gate.weight) - nn.init.constant_(self.attn_gate.bias, 4.0) - self.value_residual = value_residual - if value_residual: - self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: - bsz, seqlen, dim = x.shape - q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = F.linear(x, v_w.to(x.dtype)) - if v_embed is not None: - v = v + v_embed - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - raw_v = v if self.value_residual else None - if self.value_residual and v0 is not None: - lam = self.vr_lambda.to(dtype=v.dtype) - v = lam[0] * v0 + lam[1] * v - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin, self.rope_dims) - k = apply_rotary_emb(k, cos, sin, self.rope_dims) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - if self.gated_attention: - # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout - gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) - y = y * gate - y = y.reshape(bsz, seqlen, dim) - return F.linear(y, out_w.to(x.dtype)), raw_v - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class ValueEmbedding(nn.Module): - """Reinject token identity into attention values at specific layers. - Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" - def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): - super().__init__() - self.embed = nn.Embedding(vocab_size, ve_dim) - nn.init.normal_(self.embed.weight, std=0.01) - self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(token_ids) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - # No CastedLinear -- weights come from banks - def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: - x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) - return F.linear(x.square(), down_w.to(x.dtype)) - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - layer_idx: int = 0, - ln_scale: bool = False, - dtg: bool = False, - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - gated_attention=gated_attention, value_residual=value_residual) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - if dtg: - self.dtg_gate = nn.Linear(dim, 1, bias=True) - nn.init.zeros_(self.dtg_gate.weight) - nn.init.constant_(self.dtg_gate.bias, 2.0) - else: - self.dtg_gate = None - def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: - mix = self.resid_mix.to(dtype=x.dtype) - x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) - x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out - x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) - if self.dtg_gate is not None: - gate = torch.sigmoid(self.dtg_gate(x_in.detach())) - x_out = x_in + gate * (x_out - x_in) - return x_out, raw_v - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - dtg: bool = False, - ve_enabled: bool = False, - ve_dim: int = 128, - ve_layers: str = "9,10", - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.value_residual = value_residual - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - # Parameter banks: contiguous 3D tensors for batched optimizer - head_dim = model_dim // num_heads - kv_dim = num_kv_heads * head_dim - mlp_dim = int(mlp_mult * model_dim) - self.num_layers = num_layers - self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) - self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) - self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) - self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - layer_idx=i, - ln_scale=ln_scale, - dtg=dtg, - gated_attention=gated_attention, - value_residual=value_residual, - ) - for i in range(num_layers) - ] - ) - if rope_dims > 0: - head_dim = model_dim // num_heads - for block in self.blocks: - block.attn.rope_dims = rope_dims - block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] - kv_dim_ve = self._ve_target_dim - if self.ve_layer_indices: - self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) - self.ve_layer_scales = nn.ParameterList( - [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] - ) - else: - self.ve_shared = None - self.ve_layer_scales = nn.ParameterList() - self.value_embeds = nn.ModuleList() # keep empty for compat - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - n = self.num_layers - proj_scale = 1.0 / math.sqrt(2 * n) - # Init banks: orthogonal, with proj layers scaled down and out/down zero-init - for i in range(n): - nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q - nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) - nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K - nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V - nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up - nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) - # Scale proj layers (out_proj and mlp_down are "proj" layers) - self.qo_bank.data[n + i].mul_(proj_scale) - self.mlp_down_bank.data[i].mul_(proj_scale) - # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: - """Get value embedding for a specific layer using shared table + per-layer scale.""" - if self.ve_shared is None or layer_idx not in self.ve_layer_indices: - return None - if ve_cache is not None and 've' not in ve_cache: - ve_cache['ve'] = self.ve_shared(input_ids) - ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) - ve_idx = self.ve_layer_indices.index(layer_idx) - return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - v0 = None - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x, raw_v = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve, v0=v0) - if v0 is None and raw_v is not None: - v0 = raw_v - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x, _ = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve, v0=v0) - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - return main_loss - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - v0 = None - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x, raw_v = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve, v0=v0) - if v0 is None and raw_v is not None: - v0 = raw_v - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x, _ = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve, v0=v0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - -# --- Sliding window evaluation --- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -def eval_val_sliding_ttt( - args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, - device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - stride: int, batch_seqs: int = 32, log0=print, -) -> tuple[float, float]: - """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, - then train on it. Every token scored BEFORE any update that could use it.""" - seq_len = args.train_seq_len - total_tokens = val_tokens.numel() - 1 - ttt_chunk = args.ttt_chunk_tokens - - # Pre-compute all window starts - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - - # Assign each window to a chunk based on the first token it scores - num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk - chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] - for ws in window_starts: - end = min(ws + seq_len, total_tokens) - wlen = end - ws - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_start = ws + s - ci = min(scored_start // ttt_chunk, num_chunks - 1) - chunk_windows[ci].append(ws) - - log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " - f"total_windows={len(window_starts)} stride={stride} " - f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " - f"freeze_blocks={args.ttt_freeze_blocks}") - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Freeze first N blocks - frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) - ttt_params = [] - for name, p in base_model.named_parameters(): - freeze = False - for bi in frozen_block_ids: - if f"blocks.{bi}." in name: - freeze = True - break - if freeze: - p.requires_grad_(False) - else: - p.requires_grad_(True) - ttt_params.append(p) - - log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " - f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") - - optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) - t0 = time.perf_counter() - - for ci in range(num_chunks): - windows = chunk_windows[ci] - if not windows: - continue - chunk_start = ci * ttt_chunk - chunk_end = min((ci + 1) * ttt_chunk, total_tokens) - - # --- Phase 1: SCORE this chunk's windows (inference_mode) --- - my_s = (len(windows) * rank) // world_size - my_e = (len(windows) * (rank + 1)) // world_size - my_windows = windows[my_s:my_e] - - base_model.eval() - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk_tok[:-1] - y_batch[i, :wlen] = chunk_tok[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - # --- Phase 2: TRAIN on this chunk (already scored = legal) --- - is_last_chunk = (ci == num_chunks - 1) - if not is_last_chunk and args.ttt_epochs > 0: - base_model.train() - chunk_seqs = (chunk_end - chunk_start) // seq_len - if chunk_seqs > 0: - cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) - for pg in optimizer.param_groups: - pg['lr'] = cos_lr - my_seq_s = (chunk_seqs * rank) // world_size - my_seq_e = (chunk_seqs * (rank + 1)) // world_size - my_chunk_seqs = my_seq_e - my_seq_s - for _ep in range(args.ttt_epochs): - for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): - be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) - actual_bs = my_seq_s + bs - start_tok = chunk_start + actual_bs * seq_len - end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 - if end_tok > val_tokens.numel(): - continue - local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - optimizer.zero_grad(set_to_none=True) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = base_model(x, y) - loss.backward() - if world_size > 1: - for p in ttt_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) - optimizer.step() - - if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): - elapsed = time.perf_counter() - t0 - rl = loss_sum.item() / max(token_count.item(), 1) - rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 - log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) - - for p in base_model.parameters(): - p.requires_grad_(True) - base_model.eval() - - log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " - f"elapsed={time.perf_counter() - t0:.1f}s") - return val_loss, val_bpb - - -# --- GPTQ-lite int6 quantization --- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" -def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale - -def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: - """Convert 3D bank tensors into individual 2D tensors with standard names.""" - out: dict[str, Tensor] = {} - n = num_layers - for name, tensor in sd.items(): - if name == "qo_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] - out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] - elif name == "kv_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] - out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] - elif name == "mlp_up_bank": - for i in range(n): - out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] - elif name == "mlp_down_bank": - for i in range(n): - out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] - else: - out[name] = tensor - return out - -def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - """Convert individual 2D tensors back into 3D bank tensors.""" - out: dict[str, Tensor] = {} - n = num_layers - # Reconstruct banks from individual weight keys - qo_slices = [None] * (2 * n) - kv_slices = [None] * (2 * n) - up_slices = [None] * n - down_slices = [None] * n - consumed = set() - for i in range(n): - qk = f"blocks.{i}.attn.c_q.weight" - if qk in sd: - qo_slices[i] = sd[qk] - consumed.add(qk) - ok = f"blocks.{i}.attn.proj.weight" - if ok in sd: - qo_slices[n + i] = sd[ok] - consumed.add(ok) - kk = f"blocks.{i}.attn.c_k.weight" - if kk in sd: - kv_slices[i] = sd[kk] - consumed.add(kk) - vk = f"blocks.{i}.attn.c_v.weight" - if vk in sd: - kv_slices[n + i] = sd[vk] - consumed.add(vk) - fk = f"blocks.{i}.mlp.fc.weight" - if fk in sd: - up_slices[i] = sd[fk] - consumed.add(fk) - dk = f"blocks.{i}.mlp.proj.weight" - if dk in sd: - down_slices[i] = sd[dk] - consumed.add(dk) - out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) - out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) - out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) - out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) - for name, tensor in sd.items(): - if name not in consumed: - out[name] = tensor - return out - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - -# --- Training --- - -def main() -> None: - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - CastedLinear._qat_enabled = args.qat_enabled - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, - ve_dim=args.ve_dim, - ve_layers=args.ve_layers, - gated_attention=args.gated_attention, - value_residual=args.value_residual, - ).to(device).bfloat16() - # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward - base_model.qo_bank.data = base_model.qo_bank.data.float() - base_model.kv_bank.data = base_model.kv_bank.data.float() - base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() - base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, - # and non-bank grads are manually all-reduced before Adam steps. - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model = compiled_model - - # Optimizer split: - # - 4 parameter banks -> Muon (batched Newton-Schulz) - # - token embedding -> Adam - # - scalars/control tensors -> Adam - # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) - matrix_params = [ - base_model.qo_bank, base_model.kv_bank, - base_model.mlp_up_bank, base_model.mlp_down_bank, - ] - block_named_params = list(base_model.blocks.named_parameters()) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - scalar_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - scalar_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - # Non-bank params that need manual all-reduce (replicated across GPUs) - replicated_params = list(optimizer_tok.param_groups[0]["params"]) - for pg in optimizer_tok.param_groups[1:]: - replicated_params.extend(pg["params"]) - replicated_params.extend(scalar_params) - - optimizer_head = None - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - replicated_params.append(base_model.lm_head.weight) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if optimizer_head is not None: - optimizers.append(optimizer_head) - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - # All-reduce all grads for warmup (simple, not optimized) - if distributed: - for p in base_model.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - from collections import deque - lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = 0.997 - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - # === 3-phase overlapped optimizer step === - # Phase 1: Launch async reduce-scatter for banks (biggest first) - optimizer_muon.launch_reduce_scatters() - # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) - if distributed: - for p in replicated_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - optimizer_tok.step() - optimizer_scalar.step() - if optimizer_head is not None: - optimizer_head.step() - # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) - optimizer_muon.step() - zero_grad_all() - # EMA update - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - if args.lawa_enabled and step % args.lawa_freq == 0: - lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - # Apply weight averaging - if args.lawa_enabled and len(lawa_queue) > 1: - log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") - current_state = base_model.state_dict() - avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} - for snap in lawa_queue: - for name in avg_state: - avg_state[name] += snap[name].float() - for name in avg_state: - avg_state[name] /= len(lawa_queue) - avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) - base_model.load_state_dict(avg_state, strict=True) - else: - log0("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - torch.cuda.synchronize() - t_diag = time.perf_counter() - diag_val_loss, diag_val_bpb = eval_val( - args, compiled_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" - ) - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - # Unbank 3D tensors into individual 2D tensors for quantization - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) - quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = lzma.compress(quant_raw, preset=6) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") - log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(lzma.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) - # Re-bank the dequantized tensors - deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - gated_attention=args.gated_attention, value_residual=args.value_residual, - ).to(device).bfloat16() - eval_model.qo_bank.data = eval_model.qo_bank.data.float() - eval_model.kv_bank.data = eval_model.kv_bank.data.float() - eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() - eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - # Legal score-first TTT (PR #461 recipe) - if args.ttt_enabled: - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_loss, ttt_bpb = eval_val_sliding_ttt( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, log0=log0, - ) - torch.cuda.synchronize() - log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") - log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed1337.log b/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed1337.log deleted file mode 100644 index b4ded71789..0000000000 --- a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed1337.log +++ /dev/null @@ -1,275 +0,0 @@ -W0323 16:05:58.476000 139742 torch/distributed/run.py:803] -W0323 16:05:58.476000 139742 torch/distributed/run.py:803] ***************************************** -W0323 16:05:58.476000 139742 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0323 16:05:58.476000 139742 torch/distributed/run.py:803] ***************************************** -logs/af2d01c8-c188-494a-b1b0-f8a6adce5b39.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/root/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26928220 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.9304 val_bpb:4.1046 train_time:0ms step_avg:0.01ms -step:1/9000 train_loss:6.9322 train_time:133ms step_avg:132.54ms -step:2/9000 train_loss:8.6545 train_time:166ms step_avg:82.93ms -step:3/9000 train_loss:7.6927 train_time:246ms step_avg:81.97ms -step:4/9000 train_loss:7.2518 train_time:327ms step_avg:81.84ms -step:5/9000 train_loss:7.1707 train_time:409ms step_avg:81.81ms -step:6/9000 train_loss:7.1159 train_time:489ms step_avg:81.57ms -step:7/9000 train_loss:7.0268 train_time:571ms step_avg:81.51ms -step:8/9000 train_loss:6.9599 train_time:652ms step_avg:81.47ms -step:9/9000 train_loss:6.5750 train_time:733ms step_avg:81.47ms -step:10/9000 train_loss:6.2000 train_time:816ms step_avg:81.61ms -step:500/9000 train_loss:2.3992 train_time:41450ms step_avg:82.90ms -step:1000/9000 train_loss:2.2677 train_time:83020ms step_avg:83.02ms -step:1500/9000 train_loss:2.2127 train_time:124687ms step_avg:83.12ms -step:2000/9000 train_loss:2.0525 train_time:166425ms step_avg:83.21ms -step:2500/9000 train_loss:2.1634 train_time:208180ms step_avg:83.27ms -step:3000/9000 train_loss:2.1507 train_time:249938ms step_avg:83.31ms -step:3500/9000 train_loss:2.1664 train_time:291695ms step_avg:83.34ms -step:4000/9000 train_loss:1.9643 train_time:333469ms step_avg:83.37ms -step:4000/9000 val_loss:2.0577 val_bpb:1.2187 train_time:333519ms step_avg:83.38ms -step:4500/9000 train_loss:2.1175 train_time:375247ms step_avg:83.39ms -step:5000/9000 train_loss:2.1003 train_time:417014ms step_avg:83.40ms -step:5500/9000 train_loss:2.0130 train_time:458773ms step_avg:83.41ms -step:6000/9000 train_loss:1.9391 train_time:500541ms step_avg:83.42ms -swa:start step:6500 -step:6500/9000 train_loss:2.0804 train_time:542303ms step_avg:83.43ms -late_qat:enabled step:6662 scale:0.1498 -step:7000/9000 train_loss:1.7889 train_time:584853ms step_avg:83.55ms -step:7179/9000 val_loss:1.9214 val_bpb:1.1379 train_time:600128ms step_avg:83.59ms -stopping_early: wallclock_cap train_time:600128ms step:7179/9000 -peak memory allocated: 21471 MiB reserved: 22002 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9196 val_bpb:1.1369 eval_time:1971ms -Serialized model: 106027446 bytes -Code size: 89458 bytes -Serialized model int6+lzma: 15887928 bytes -Total submission size int6+lzma: 15977386 bytes -final_int6_roundtrip val_loss:1.9336 val_bpb:1.1452 eval_time:6562ms -final_int6_roundtrip_exact val_loss:1.93358885 val_bpb:1.14518023 -final_int6_sliding_window val_loss:1.8939 val_bpb:1.1217 stride:64 eval_time:74066ms -final_int6_sliding_window_exact val_loss:1.89385556 val_bpb:1.12165091 -final_int8_zlib_roundtrip_exact val_loss:1.89385556 val_bpb:1.12165091 -ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 -ttt_sliding:params unfrozen=26928220 frozen=0 - ttt_chunk [1/1893] bpb=1.159980 time=0.5s - ttt_chunk [11/1893] bpb=1.146741 time=2.7s - ttt_chunk [21/1893] bpb=1.131759 time=4.8s - ttt_chunk [31/1893] bpb=1.129251 time=7.0s - ttt_chunk [41/1893] bpb=1.116035 time=9.1s - ttt_chunk [51/1893] bpb=1.110636 time=11.3s - ttt_chunk [61/1893] bpb=1.117614 time=13.4s - ttt_chunk [71/1893] bpb=1.116068 time=15.6s - ttt_chunk [81/1893] bpb=1.115280 time=17.8s - ttt_chunk [91/1893] bpb=1.115913 time=19.9s - ttt_chunk [101/1893] bpb=1.119640 time=22.1s - ttt_chunk [111/1893] bpb=1.121926 time=24.2s - ttt_chunk [121/1893] bpb=1.115465 time=26.4s - ttt_chunk [131/1893] bpb=1.115687 time=28.5s - ttt_chunk [141/1893] bpb=1.121404 time=30.7s - ttt_chunk [151/1893] bpb=1.123210 time=32.9s - ttt_chunk [161/1893] bpb=1.122919 time=35.0s - ttt_chunk [171/1893] bpb=1.127371 time=37.2s - ttt_chunk [181/1893] bpb=1.129716 time=39.3s - ttt_chunk [191/1893] bpb=1.137063 time=41.5s - ttt_chunk [201/1893] bpb=1.135822 time=43.7s - ttt_chunk [211/1893] bpb=1.133605 time=45.8s - ttt_chunk [221/1893] bpb=1.135028 time=48.0s - ttt_chunk [231/1893] bpb=1.133602 time=50.1s - ttt_chunk [241/1893] bpb=1.133887 time=52.3s - ttt_chunk [251/1893] bpb=1.133402 time=54.5s - ttt_chunk [261/1893] bpb=1.130566 time=56.6s - ttt_chunk [271/1893] bpb=1.129398 time=58.8s - ttt_chunk [281/1893] bpb=1.130685 time=61.0s - ttt_chunk [291/1893] bpb=1.132438 time=63.1s - ttt_chunk [301/1893] bpb=1.133145 time=65.3s - ttt_chunk [311/1893] bpb=1.135237 time=67.4s - ttt_chunk [321/1893] bpb=1.137158 time=69.6s - ttt_chunk [331/1893] bpb=1.136909 time=71.8s - ttt_chunk [341/1893] bpb=1.135973 time=73.9s - ttt_chunk [351/1893] bpb=1.138226 time=76.1s - ttt_chunk [361/1893] bpb=1.138372 time=78.3s - ttt_chunk [371/1893] bpb=1.137648 time=80.4s - ttt_chunk [381/1893] bpb=1.137809 time=82.6s - ttt_chunk [391/1893] bpb=1.137588 time=84.7s - ttt_chunk [401/1893] bpb=1.135457 time=87.1s - ttt_chunk [411/1893] bpb=1.134245 time=89.3s - ttt_chunk [421/1893] bpb=1.133342 time=91.4s - ttt_chunk [431/1893] bpb=1.133174 time=93.6s - ttt_chunk [441/1893] bpb=1.133560 time=95.8s - ttt_chunk [451/1893] bpb=1.133816 time=98.0s - ttt_chunk [461/1893] bpb=1.132666 time=100.1s - ttt_chunk [471/1893] bpb=1.133250 time=102.3s - ttt_chunk [481/1893] bpb=1.132911 time=104.5s - ttt_chunk [491/1893] bpb=1.131846 time=106.6s - ttt_chunk [501/1893] bpb=1.131375 time=108.8s - ttt_chunk [511/1893] bpb=1.130701 time=110.9s - ttt_chunk [521/1893] bpb=1.128297 time=113.1s - ttt_chunk [531/1893] bpb=1.129482 time=115.2s - ttt_chunk [541/1893] bpb=1.129803 time=117.4s - ttt_chunk [551/1893] bpb=1.128773 time=119.6s - ttt_chunk [561/1893] bpb=1.129299 time=121.7s - ttt_chunk [571/1893] bpb=1.128266 time=123.9s - ttt_chunk [581/1893] bpb=1.127432 time=126.0s - ttt_chunk [591/1893] bpb=1.126796 time=128.3s - ttt_chunk [601/1893] bpb=1.127264 time=130.4s - ttt_chunk [611/1893] bpb=1.127183 time=132.6s - ttt_chunk [621/1893] bpb=1.127047 time=134.7s - ttt_chunk [631/1893] bpb=1.127715 time=136.9s - ttt_chunk [641/1893] bpb=1.127455 time=139.1s - ttt_chunk [651/1893] bpb=1.127603 time=141.2s - ttt_chunk [661/1893] bpb=1.127066 time=143.4s - ttt_chunk [671/1893] bpb=1.127382 time=145.5s - ttt_chunk [681/1893] bpb=1.128066 time=147.7s - ttt_chunk [691/1893] bpb=1.129077 time=149.9s - ttt_chunk [701/1893] bpb=1.128554 time=152.0s - ttt_chunk [711/1893] bpb=1.128542 time=154.2s - ttt_chunk [721/1893] bpb=1.128212 time=156.3s - ttt_chunk [731/1893] bpb=1.128268 time=158.5s - ttt_chunk [741/1893] bpb=1.128398 time=160.7s - ttt_chunk [751/1893] bpb=1.128249 time=162.8s - ttt_chunk [761/1893] bpb=1.128211 time=165.0s - ttt_chunk [771/1893] bpb=1.127869 time=167.1s - ttt_chunk [781/1893] bpb=1.128624 time=169.3s - ttt_chunk [791/1893] bpb=1.128218 time=171.5s - ttt_chunk [801/1893] bpb=1.128539 time=173.6s - ttt_chunk [811/1893] bpb=1.128319 time=175.8s - ttt_chunk [821/1893] bpb=1.128102 time=178.0s - ttt_chunk [831/1893] bpb=1.127943 time=180.1s - ttt_chunk [841/1893] bpb=1.127322 time=182.3s - ttt_chunk [851/1893] bpb=1.127076 time=184.5s - ttt_chunk [861/1893] bpb=1.126825 time=186.6s - ttt_chunk [871/1893] bpb=1.127099 time=188.8s - ttt_chunk [881/1893] bpb=1.127289 time=190.9s - ttt_chunk [891/1893] bpb=1.126847 time=193.1s - ttt_chunk [901/1893] bpb=1.126579 time=195.2s - ttt_chunk [911/1893] bpb=1.126697 time=197.4s - ttt_chunk [921/1893] bpb=1.127204 time=199.6s - ttt_chunk [931/1893] bpb=1.127192 time=201.7s - ttt_chunk [941/1893] bpb=1.126846 time=203.9s - ttt_chunk [951/1893] bpb=1.127227 time=206.0s - ttt_chunk [961/1893] bpb=1.127318 time=208.2s - ttt_chunk [971/1893] bpb=1.128176 time=210.4s - ttt_chunk [981/1893] bpb=1.128236 time=212.6s - ttt_chunk [991/1893] bpb=1.128215 time=214.7s - ttt_chunk [1001/1893] bpb=1.128164 time=216.9s - ttt_chunk [1011/1893] bpb=1.127968 time=219.0s - ttt_chunk [1021/1893] bpb=1.128320 time=221.2s - ttt_chunk [1031/1893] bpb=1.128764 time=223.3s - ttt_chunk [1041/1893] bpb=1.128402 time=225.5s - ttt_chunk [1051/1893] bpb=1.128156 time=227.7s - ttt_chunk [1061/1893] bpb=1.128204 time=229.8s - ttt_chunk [1071/1893] bpb=1.128853 time=232.0s - ttt_chunk [1081/1893] bpb=1.129143 time=234.2s - ttt_chunk [1091/1893] bpb=1.129876 time=236.3s - ttt_chunk [1101/1893] bpb=1.129889 time=238.5s - ttt_chunk [1111/1893] bpb=1.129732 time=240.6s - ttt_chunk [1121/1893] bpb=1.129538 time=242.8s - ttt_chunk [1131/1893] bpb=1.129404 time=244.9s - ttt_chunk [1141/1893] bpb=1.129105 time=247.1s - ttt_chunk [1151/1893] bpb=1.129106 time=249.3s - ttt_chunk [1161/1893] bpb=1.128706 time=251.5s - ttt_chunk [1171/1893] bpb=1.129036 time=253.7s - ttt_chunk [1181/1893] bpb=1.128284 time=255.9s - ttt_chunk [1191/1893] bpb=1.128158 time=258.1s - ttt_chunk [1201/1893] bpb=1.128567 time=260.2s - ttt_chunk [1211/1893] bpb=1.128096 time=262.4s - ttt_chunk [1221/1893] bpb=1.127782 time=264.5s - ttt_chunk [1231/1893] bpb=1.127494 time=266.7s - ttt_chunk [1241/1893] bpb=1.127163 time=268.9s - ttt_chunk [1251/1893] bpb=1.126589 time=271.0s - ttt_chunk [1261/1893] bpb=1.126572 time=273.2s - ttt_chunk [1271/1893] bpb=1.126191 time=275.3s - ttt_chunk [1281/1893] bpb=1.126002 time=277.5s - ttt_chunk [1291/1893] bpb=1.125767 time=279.6s - ttt_chunk [1301/1893] bpb=1.125183 time=281.8s - ttt_chunk [1311/1893] bpb=1.124788 time=284.0s - ttt_chunk [1321/1893] bpb=1.124455 time=286.1s - ttt_chunk [1331/1893] bpb=1.124379 time=288.3s - ttt_chunk [1341/1893] bpb=1.124254 time=290.5s - ttt_chunk [1351/1893] bpb=1.124182 time=292.6s - ttt_chunk [1361/1893] bpb=1.124250 time=294.8s - ttt_chunk [1371/1893] bpb=1.124118 time=297.0s - ttt_chunk [1381/1893] bpb=1.124101 time=299.1s - ttt_chunk [1391/1893] bpb=1.123708 time=301.3s - ttt_chunk [1401/1893] bpb=1.123674 time=303.4s - ttt_chunk [1411/1893] bpb=1.123790 time=305.6s - ttt_chunk [1421/1893] bpb=1.124052 time=307.8s - ttt_chunk [1431/1893] bpb=1.123759 time=309.9s - ttt_chunk [1441/1893] bpb=1.124264 time=312.1s - ttt_chunk [1451/1893] bpb=1.124605 time=314.3s - ttt_chunk [1461/1893] bpb=1.124147 time=316.4s - ttt_chunk [1471/1893] bpb=1.125183 time=318.6s - ttt_chunk [1481/1893] bpb=1.124739 time=320.7s - ttt_chunk [1491/1893] bpb=1.124560 time=322.9s - ttt_chunk [1501/1893] bpb=1.124466 time=325.0s - ttt_chunk [1511/1893] bpb=1.124486 time=327.2s - ttt_chunk [1521/1893] bpb=1.124509 time=329.4s - ttt_chunk [1531/1893] bpb=1.123992 time=331.6s - ttt_chunk [1541/1893] bpb=1.123849 time=333.7s - ttt_chunk [1551/1893] bpb=1.124168 time=335.9s - ttt_chunk [1561/1893] bpb=1.124173 time=338.0s - ttt_chunk [1571/1893] bpb=1.124012 time=340.2s - ttt_chunk [1581/1893] bpb=1.124107 time=342.4s - ttt_chunk [1591/1893] bpb=1.123963 time=344.5s - ttt_chunk [1601/1893] bpb=1.124145 time=346.8s - ttt_chunk [1611/1893] bpb=1.124075 time=349.0s - ttt_chunk [1621/1893] bpb=1.123673 time=351.1s - ttt_chunk [1631/1893] bpb=1.123969 time=353.3s - ttt_chunk [1641/1893] bpb=1.123971 time=355.5s - ttt_chunk [1651/1893] bpb=1.123931 time=357.6s - ttt_chunk [1661/1893] bpb=1.123807 time=359.8s - ttt_chunk [1671/1893] bpb=1.124271 time=361.9s - ttt_chunk [1681/1893] bpb=1.124429 time=364.1s - ttt_chunk [1691/1893] bpb=1.124252 time=366.3s - ttt_chunk [1701/1893] bpb=1.124410 time=368.4s - ttt_chunk [1711/1893] bpb=1.124423 time=370.6s - ttt_chunk [1721/1893] bpb=1.124424 time=372.7s - ttt_chunk [1731/1893] bpb=1.124296 time=374.9s - ttt_chunk [1741/1893] bpb=1.124094 time=377.1s - ttt_chunk [1751/1893] bpb=1.123923 time=379.2s - ttt_chunk [1761/1893] bpb=1.124061 time=381.4s - ttt_chunk [1771/1893] bpb=1.123969 time=383.5s - ttt_chunk [1781/1893] bpb=1.123989 time=385.7s - ttt_chunk [1791/1893] bpb=1.123582 time=387.9s - ttt_chunk [1801/1893] bpb=1.123452 time=390.0s - ttt_chunk [1811/1893] bpb=1.123351 time=392.2s - ttt_chunk [1821/1893] bpb=1.123415 time=394.3s - ttt_chunk [1831/1893] bpb=1.122820 time=396.5s - ttt_chunk [1841/1893] bpb=1.122755 time=398.7s - ttt_chunk [1851/1893] bpb=1.122549 time=400.8s - ttt_chunk [1861/1893] bpb=1.122183 time=403.0s - ttt_chunk [1871/1893] bpb=1.122167 time=405.1s - ttt_chunk [1881/1893] bpb=1.121710 time=407.3s - ttt_chunk [1891/1893] bpb=1.121468 time=409.5s - ttt_chunk [1893/1893] bpb=1.121512 time=409.7s -ttt_sliding:done val_loss=1.889768 val_bpb=1.119230 elapsed=409.8s -legal_ttt val_loss:1.8898 val_bpb:1.1192 eval_time:410268ms -legal_ttt_exact val_loss:1.88976776 val_bpb:1.11922988 diff --git a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed2025.log b/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed2025.log deleted file mode 100644 index 9fd5f6fa56..0000000000 --- a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed2025.log +++ /dev/null @@ -1,275 +0,0 @@ -W0323 15:44:02.717000 138514 torch/distributed/run.py:803] -W0323 15:44:02.717000 138514 torch/distributed/run.py:803] ***************************************** -W0323 15:44:02.717000 138514 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0323 15:44:02.717000 138514 torch/distributed/run.py:803] ***************************************** -logs/7bdd34d1-526e-4f6c-ba9e-7b92a0c4736b.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/root/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26928220 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:2025 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.9302 val_bpb:4.1045 train_time:0ms step_avg:0.01ms -step:1/9000 train_loss:6.9311 train_time:131ms step_avg:130.73ms -step:2/9000 train_loss:8.6819 train_time:160ms step_avg:80.02ms -step:3/9000 train_loss:7.7058 train_time:241ms step_avg:80.20ms -step:4/9000 train_loss:7.2718 train_time:321ms step_avg:80.31ms -step:5/9000 train_loss:7.1777 train_time:402ms step_avg:80.43ms -step:6/9000 train_loss:7.0947 train_time:483ms step_avg:80.53ms -step:7/9000 train_loss:7.0222 train_time:564ms step_avg:80.58ms -step:8/9000 train_loss:6.9409 train_time:645ms step_avg:80.67ms -step:9/9000 train_loss:6.6069 train_time:726ms step_avg:80.64ms -step:10/9000 train_loss:6.2031 train_time:807ms step_avg:80.69ms -step:500/9000 train_loss:2.4020 train_time:41378ms step_avg:82.76ms -step:1000/9000 train_loss:2.2623 train_time:82897ms step_avg:82.90ms -step:1500/9000 train_loss:2.2112 train_time:124470ms step_avg:82.98ms -step:2000/9000 train_loss:2.0540 train_time:166113ms step_avg:83.06ms -step:2500/9000 train_loss:2.1597 train_time:207787ms step_avg:83.11ms -step:3000/9000 train_loss:2.1511 train_time:249483ms step_avg:83.16ms -step:3500/9000 train_loss:2.1673 train_time:291182ms step_avg:83.19ms -step:4000/9000 train_loss:1.9706 train_time:332892ms step_avg:83.22ms -step:4000/9000 val_loss:2.0569 val_bpb:1.2182 train_time:332947ms step_avg:83.24ms -step:4500/9000 train_loss:2.1149 train_time:374614ms step_avg:83.25ms -step:5000/9000 train_loss:2.0982 train_time:416325ms step_avg:83.26ms -step:5500/9000 train_loss:2.0155 train_time:458013ms step_avg:83.28ms -step:6000/9000 train_loss:1.9379 train_time:499707ms step_avg:83.28ms -step:6500/9000 train_loss:2.0789 train_time:541386ms step_avg:83.29ms -swa:start step:6550 -late_qat:enabled step:6675 scale:0.1499 -step:7000/9000 train_loss:1.7875 train_time:583724ms step_avg:83.39ms -step:7193/9000 val_loss:1.9204 val_bpb:1.1374 train_time:600110ms step_avg:83.43ms -stopping_early: wallclock_cap train_time:600110ms step:7193/9000 -peak memory allocated: 21471 MiB reserved: 22002 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9187 val_bpb:1.1363 eval_time:1978ms -Serialized model: 106027446 bytes -Code size: 89458 bytes -Serialized model int6+lzma: 15900548 bytes -Total submission size int6+lzma: 15990006 bytes -final_int6_roundtrip val_loss:1.9327 val_bpb:1.1447 eval_time:6601ms -final_int6_roundtrip_exact val_loss:1.93273877 val_bpb:1.14467677 -final_int6_sliding_window val_loss:1.8931 val_bpb:1.1212 stride:64 eval_time:73676ms -final_int6_sliding_window_exact val_loss:1.89307789 val_bpb:1.12119033 -final_int8_zlib_roundtrip_exact val_loss:1.89307789 val_bpb:1.12119033 -ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 -ttt_sliding:params unfrozen=26928220 frozen=0 - ttt_chunk [1/1893] bpb=1.152282 time=0.5s - ttt_chunk [11/1893] bpb=1.146772 time=2.7s - ttt_chunk [21/1893] bpb=1.131382 time=4.8s - ttt_chunk [31/1893] bpb=1.129534 time=7.0s - ttt_chunk [41/1893] bpb=1.116460 time=9.1s - ttt_chunk [51/1893] bpb=1.110423 time=11.3s - ttt_chunk [61/1893] bpb=1.117251 time=13.4s - ttt_chunk [71/1893] bpb=1.115874 time=15.6s - ttt_chunk [81/1893] bpb=1.115346 time=17.8s - ttt_chunk [91/1893] bpb=1.116215 time=19.9s - ttt_chunk [101/1893] bpb=1.119896 time=22.1s - ttt_chunk [111/1893] bpb=1.122270 time=24.3s - ttt_chunk [121/1893] bpb=1.115611 time=26.4s - ttt_chunk [131/1893] bpb=1.115856 time=28.6s - ttt_chunk [141/1893] bpb=1.121426 time=30.8s - ttt_chunk [151/1893] bpb=1.123119 time=32.9s - ttt_chunk [161/1893] bpb=1.122850 time=35.0s - ttt_chunk [171/1893] bpb=1.127140 time=37.2s - ttt_chunk [181/1893] bpb=1.129331 time=39.4s - ttt_chunk [191/1893] bpb=1.136699 time=41.5s - ttt_chunk [201/1893] bpb=1.135510 time=43.7s - ttt_chunk [211/1893] bpb=1.133369 time=45.9s - ttt_chunk [221/1893] bpb=1.134856 time=48.0s - ttt_chunk [231/1893] bpb=1.133523 time=50.2s - ttt_chunk [241/1893] bpb=1.133875 time=52.3s - ttt_chunk [251/1893] bpb=1.133357 time=54.5s - ttt_chunk [261/1893] bpb=1.130421 time=56.6s - ttt_chunk [271/1893] bpb=1.129357 time=58.8s - ttt_chunk [281/1893] bpb=1.130584 time=61.0s - ttt_chunk [291/1893] bpb=1.132247 time=63.1s - ttt_chunk [301/1893] bpb=1.132970 time=65.2s - ttt_chunk [311/1893] bpb=1.135004 time=67.4s - ttt_chunk [321/1893] bpb=1.136960 time=69.6s - ttt_chunk [331/1893] bpb=1.136742 time=71.7s - ttt_chunk [341/1893] bpb=1.135738 time=74.0s - ttt_chunk [351/1893] bpb=1.137965 time=76.1s - ttt_chunk [361/1893] bpb=1.138106 time=78.2s - ttt_chunk [371/1893] bpb=1.137349 time=80.4s - ttt_chunk [381/1893] bpb=1.137521 time=82.5s - ttt_chunk [391/1893] bpb=1.137343 time=84.6s - ttt_chunk [401/1893] bpb=1.135290 time=86.8s - ttt_chunk [411/1893] bpb=1.134107 time=89.0s - ttt_chunk [421/1893] bpb=1.133185 time=91.1s - ttt_chunk [431/1893] bpb=1.133024 time=93.3s - ttt_chunk [441/1893] bpb=1.133443 time=95.5s - ttt_chunk [451/1893] bpb=1.133772 time=97.6s - ttt_chunk [461/1893] bpb=1.132625 time=99.7s - ttt_chunk [471/1893] bpb=1.133228 time=101.9s - ttt_chunk [481/1893] bpb=1.132845 time=104.0s - ttt_chunk [491/1893] bpb=1.131737 time=106.2s - ttt_chunk [501/1893] bpb=1.131240 time=108.3s - ttt_chunk [511/1893] bpb=1.130566 time=110.4s - ttt_chunk [521/1893] bpb=1.128170 time=112.6s - ttt_chunk [531/1893] bpb=1.129307 time=114.7s - ttt_chunk [541/1893] bpb=1.129636 time=116.9s - ttt_chunk [551/1893] bpb=1.128604 time=119.0s - ttt_chunk [561/1893] bpb=1.129129 time=121.2s - ttt_chunk [571/1893] bpb=1.128048 time=123.4s - ttt_chunk [581/1893] bpb=1.127259 time=125.5s - ttt_chunk [591/1893] bpb=1.126614 time=127.7s - ttt_chunk [601/1893] bpb=1.127054 time=129.8s - ttt_chunk [611/1893] bpb=1.126950 time=131.9s - ttt_chunk [621/1893] bpb=1.126828 time=134.1s - ttt_chunk [631/1893] bpb=1.127544 time=136.2s - ttt_chunk [641/1893] bpb=1.127276 time=138.4s - ttt_chunk [651/1893] bpb=1.127391 time=140.5s - ttt_chunk [661/1893] bpb=1.126855 time=142.6s - ttt_chunk [671/1893] bpb=1.127210 time=144.8s - ttt_chunk [681/1893] bpb=1.127887 time=147.0s - ttt_chunk [691/1893] bpb=1.128838 time=149.1s - ttt_chunk [701/1893] bpb=1.128250 time=151.3s - ttt_chunk [711/1893] bpb=1.128254 time=153.4s - ttt_chunk [721/1893] bpb=1.127885 time=155.5s - ttt_chunk [731/1893] bpb=1.127934 time=157.7s - ttt_chunk [741/1893] bpb=1.128045 time=159.8s - ttt_chunk [751/1893] bpb=1.127901 time=161.9s - ttt_chunk [761/1893] bpb=1.127837 time=164.1s - ttt_chunk [771/1893] bpb=1.127515 time=166.2s - ttt_chunk [781/1893] bpb=1.128260 time=168.4s - ttt_chunk [791/1893] bpb=1.127885 time=170.5s - ttt_chunk [801/1893] bpb=1.128179 time=172.7s - ttt_chunk [811/1893] bpb=1.127949 time=174.8s - ttt_chunk [821/1893] bpb=1.127719 time=176.9s - ttt_chunk [831/1893] bpb=1.127589 time=179.1s - ttt_chunk [841/1893] bpb=1.126934 time=181.2s - ttt_chunk [851/1893] bpb=1.126655 time=183.4s - ttt_chunk [861/1893] bpb=1.126408 time=185.5s - ttt_chunk [871/1893] bpb=1.126708 time=187.7s - ttt_chunk [881/1893] bpb=1.126886 time=189.8s - ttt_chunk [891/1893] bpb=1.126462 time=192.0s - ttt_chunk [901/1893] bpb=1.126196 time=194.2s - ttt_chunk [911/1893] bpb=1.126327 time=196.3s - ttt_chunk [921/1893] bpb=1.126828 time=198.5s - ttt_chunk [931/1893] bpb=1.126803 time=200.6s - ttt_chunk [941/1893] bpb=1.126489 time=202.8s - ttt_chunk [951/1893] bpb=1.126884 time=205.0s - ttt_chunk [961/1893] bpb=1.126953 time=207.1s - ttt_chunk [971/1893] bpb=1.127810 time=209.3s - ttt_chunk [981/1893] bpb=1.127883 time=211.4s - ttt_chunk [991/1893] bpb=1.127868 time=213.6s - ttt_chunk [1001/1893] bpb=1.127839 time=215.7s - ttt_chunk [1011/1893] bpb=1.127632 time=217.9s - ttt_chunk [1021/1893] bpb=1.127972 time=220.0s - ttt_chunk [1031/1893] bpb=1.128407 time=222.2s - ttt_chunk [1041/1893] bpb=1.128058 time=224.3s - ttt_chunk [1051/1893] bpb=1.127811 time=226.5s - ttt_chunk [1061/1893] bpb=1.127859 time=228.8s - ttt_chunk [1071/1893] bpb=1.128487 time=231.2s - ttt_chunk [1081/1893] bpb=1.128762 time=233.4s - ttt_chunk [1091/1893] bpb=1.129511 time=235.6s - ttt_chunk [1101/1893] bpb=1.129513 time=237.7s - ttt_chunk [1111/1893] bpb=1.129367 time=239.9s - ttt_chunk [1121/1893] bpb=1.129154 time=242.0s - ttt_chunk [1131/1893] bpb=1.129018 time=244.1s - ttt_chunk [1141/1893] bpb=1.128719 time=246.3s - ttt_chunk [1151/1893] bpb=1.128722 time=248.4s - ttt_chunk [1161/1893] bpb=1.128339 time=250.6s - ttt_chunk [1171/1893] bpb=1.128671 time=252.7s - ttt_chunk [1181/1893] bpb=1.127917 time=254.9s - ttt_chunk [1191/1893] bpb=1.127781 time=257.0s - ttt_chunk [1201/1893] bpb=1.128201 time=259.2s - ttt_chunk [1211/1893] bpb=1.127720 time=261.3s - ttt_chunk [1221/1893] bpb=1.127433 time=263.4s - ttt_chunk [1231/1893] bpb=1.127158 time=265.6s - ttt_chunk [1241/1893] bpb=1.126821 time=267.7s - ttt_chunk [1251/1893] bpb=1.126240 time=269.9s - ttt_chunk [1261/1893] bpb=1.126196 time=272.0s - ttt_chunk [1271/1893] bpb=1.125822 time=274.2s - ttt_chunk [1281/1893] bpb=1.125625 time=276.3s - ttt_chunk [1291/1893] bpb=1.125380 time=278.5s - ttt_chunk [1301/1893] bpb=1.124783 time=280.6s - ttt_chunk [1311/1893] bpb=1.124396 time=282.8s - ttt_chunk [1321/1893] bpb=1.124057 time=284.9s - ttt_chunk [1331/1893] bpb=1.123985 time=287.1s - ttt_chunk [1341/1893] bpb=1.123840 time=289.2s - ttt_chunk [1351/1893] bpb=1.123753 time=291.4s - ttt_chunk [1361/1893] bpb=1.123814 time=293.5s - ttt_chunk [1371/1893] bpb=1.123683 time=295.7s - ttt_chunk [1381/1893] bpb=1.123677 time=297.8s - ttt_chunk [1391/1893] bpb=1.123287 time=299.9s - ttt_chunk [1401/1893] bpb=1.123259 time=302.1s - ttt_chunk [1411/1893] bpb=1.123394 time=304.2s - ttt_chunk [1421/1893] bpb=1.123643 time=306.4s - ttt_chunk [1431/1893] bpb=1.123362 time=308.5s - ttt_chunk [1441/1893] bpb=1.123857 time=310.7s - ttt_chunk [1451/1893] bpb=1.124190 time=312.8s - ttt_chunk [1461/1893] bpb=1.123725 time=315.0s - ttt_chunk [1471/1893] bpb=1.124770 time=317.1s - ttt_chunk [1481/1893] bpb=1.124311 time=319.3s - ttt_chunk [1491/1893] bpb=1.124134 time=321.4s - ttt_chunk [1501/1893] bpb=1.124026 time=323.5s - ttt_chunk [1511/1893] bpb=1.124042 time=325.7s - ttt_chunk [1521/1893] bpb=1.124049 time=327.8s - ttt_chunk [1531/1893] bpb=1.123528 time=330.0s - ttt_chunk [1541/1893] bpb=1.123385 time=332.1s - ttt_chunk [1551/1893] bpb=1.123700 time=334.3s - ttt_chunk [1561/1893] bpb=1.123702 time=336.4s - ttt_chunk [1571/1893] bpb=1.123545 time=338.6s - ttt_chunk [1581/1893] bpb=1.123651 time=340.7s - ttt_chunk [1591/1893] bpb=1.123514 time=342.8s - ttt_chunk [1601/1893] bpb=1.123694 time=345.0s - ttt_chunk [1611/1893] bpb=1.123636 time=347.1s - ttt_chunk [1621/1893] bpb=1.123234 time=349.3s - ttt_chunk [1631/1893] bpb=1.123540 time=351.4s - ttt_chunk [1641/1893] bpb=1.123533 time=353.6s - ttt_chunk [1651/1893] bpb=1.123500 time=355.7s - ttt_chunk [1661/1893] bpb=1.123379 time=357.9s - ttt_chunk [1671/1893] bpb=1.123867 time=360.0s - ttt_chunk [1681/1893] bpb=1.124007 time=362.2s - ttt_chunk [1691/1893] bpb=1.123839 time=364.3s - ttt_chunk [1701/1893] bpb=1.123997 time=366.4s - ttt_chunk [1711/1893] bpb=1.123990 time=368.6s - ttt_chunk [1721/1893] bpb=1.123994 time=370.7s - ttt_chunk [1731/1893] bpb=1.123871 time=372.9s - ttt_chunk [1741/1893] bpb=1.123674 time=375.0s - ttt_chunk [1751/1893] bpb=1.123508 time=377.2s - ttt_chunk [1761/1893] bpb=1.123657 time=379.3s - ttt_chunk [1771/1893] bpb=1.123559 time=381.5s - ttt_chunk [1781/1893] bpb=1.123594 time=383.6s - ttt_chunk [1791/1893] bpb=1.123180 time=385.7s - ttt_chunk [1801/1893] bpb=1.123057 time=387.9s - ttt_chunk [1811/1893] bpb=1.122950 time=390.0s - ttt_chunk [1821/1893] bpb=1.123006 time=392.2s - ttt_chunk [1831/1893] bpb=1.122406 time=394.3s - ttt_chunk [1841/1893] bpb=1.122337 time=396.5s - ttt_chunk [1851/1893] bpb=1.122135 time=398.6s - ttt_chunk [1861/1893] bpb=1.121779 time=400.8s - ttt_chunk [1871/1893] bpb=1.121774 time=402.9s - ttt_chunk [1881/1893] bpb=1.121330 time=405.0s - ttt_chunk [1891/1893] bpb=1.121096 time=407.2s - ttt_chunk [1893/1893] bpb=1.121136 time=407.5s -ttt_sliding:done val_loss=1.889192 val_bpb=1.118889 elapsed=407.5s -legal_ttt val_loss:1.8892 val_bpb:1.1189 eval_time:407984ms -legal_ttt_exact val_loss:1.88919190 val_bpb:1.11888882 diff --git a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed42.log b/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed42.log deleted file mode 100644 index 8ebc281ed1..0000000000 --- a/records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_seed42.log +++ /dev/null @@ -1,275 +0,0 @@ -W0323 15:20:53.391000 134686 torch/distributed/run.py:803] -W0323 15:20:53.391000 134686 torch/distributed/run.py:803] ***************************************** -W0323 15:20:53.391000 134686 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0323 15:20:53.391000 134686 torch/distributed/run.py:803] ***************************************** -logs/da29a702-7295-47ef-8940-4be71fa4fdf7.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf/data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=/root/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:26928220 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_4 active_layers:[7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/9000 val_loss:6.9293 val_bpb:4.1039 train_time:0ms step_avg:0.01ms -step:1/9000 train_loss:6.9308 train_time:131ms step_avg:130.72ms -step:2/9000 train_loss:8.6422 train_time:159ms step_avg:79.60ms -step:3/9000 train_loss:7.6901 train_time:243ms step_avg:81.12ms -step:4/9000 train_loss:7.2779 train_time:324ms step_avg:81.00ms -step:5/9000 train_loss:7.2217 train_time:406ms step_avg:81.22ms -step:6/9000 train_loss:7.1406 train_time:487ms step_avg:81.10ms -step:7/9000 train_loss:7.0922 train_time:568ms step_avg:81.08ms -step:8/9000 train_loss:7.0289 train_time:649ms step_avg:81.08ms -step:9/9000 train_loss:6.6338 train_time:730ms step_avg:81.08ms -step:10/9000 train_loss:6.2566 train_time:811ms step_avg:81.13ms -step:500/9000 train_loss:2.3966 train_time:41425ms step_avg:82.85ms -step:1000/9000 train_loss:2.2607 train_time:82976ms step_avg:82.98ms -step:1500/9000 train_loss:2.2114 train_time:124600ms step_avg:83.07ms -step:2000/9000 train_loss:2.0543 train_time:166321ms step_avg:83.16ms -step:2500/9000 train_loss:2.1581 train_time:208061ms step_avg:83.22ms -step:3000/9000 train_loss:2.1522 train_time:249813ms step_avg:83.27ms -step:3500/9000 train_loss:2.1732 train_time:291582ms step_avg:83.31ms -step:4000/9000 train_loss:1.9667 train_time:333329ms step_avg:83.33ms -step:4000/9000 val_loss:2.0595 val_bpb:1.2197 train_time:333381ms step_avg:83.35ms -step:4500/9000 train_loss:2.1157 train_time:375087ms step_avg:83.35ms -step:5000/9000 train_loss:2.1011 train_time:416858ms step_avg:83.37ms -step:5500/9000 train_loss:2.0126 train_time:458617ms step_avg:83.38ms -step:6000/9000 train_loss:1.9423 train_time:500364ms step_avg:83.39ms -swa:start step:6500 -step:6500/9000 train_loss:2.0788 train_time:542114ms step_avg:83.40ms -late_qat:enabled step:6664 scale:0.1499 -step:7000/9000 train_loss:1.7889 train_time:584621ms step_avg:83.52ms -step:7182/9000 val_loss:1.9225 val_bpb:1.1386 train_time:600128ms step_avg:83.56ms -stopping_early: wallclock_cap train_time:600128ms step:7182/9000 -peak memory allocated: 21471 MiB reserved: 22000 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9208 val_bpb:1.1376 eval_time:1977ms -Serialized model: 106027446 bytes -Code size: 89458 bytes -Serialized model int6+lzma: 15787052 bytes -Total submission size int6+lzma: 15876510 bytes -final_int6_roundtrip val_loss:1.9350 val_bpb:1.1460 eval_time:20430ms -final_int6_roundtrip_exact val_loss:1.93502821 val_bpb:1.14603270 -final_int6_sliding_window val_loss:1.8956 val_bpb:1.1227 stride:64 eval_time:97749ms -final_int6_sliding_window_exact val_loss:1.89556781 val_bpb:1.12266500 -final_int8_zlib_roundtrip_exact val_loss:1.89556781 val_bpb:1.12266500 -ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 -ttt_sliding:params unfrozen=26928220 frozen=0 - ttt_chunk [1/1893] bpb=1.161909 time=0.5s - ttt_chunk [11/1893] bpb=1.146988 time=2.7s - ttt_chunk [21/1893] bpb=1.133321 time=4.8s - ttt_chunk [31/1893] bpb=1.131377 time=7.0s - ttt_chunk [41/1893] bpb=1.117512 time=9.2s - ttt_chunk [51/1893] bpb=1.111922 time=11.3s - ttt_chunk [61/1893] bpb=1.118583 time=13.4s - ttt_chunk [71/1893] bpb=1.116834 time=15.6s - ttt_chunk [81/1893] bpb=1.116160 time=17.7s - ttt_chunk [91/1893] bpb=1.117028 time=19.9s - ttt_chunk [101/1893] bpb=1.120690 time=22.0s - ttt_chunk [111/1893] bpb=1.123170 time=24.2s - ttt_chunk [121/1893] bpb=1.116607 time=26.3s - ttt_chunk [131/1893] bpb=1.116961 time=28.5s - ttt_chunk [141/1893] bpb=1.122351 time=30.6s - ttt_chunk [151/1893] bpb=1.124214 time=32.8s - ttt_chunk [161/1893] bpb=1.123841 time=34.9s - ttt_chunk [171/1893] bpb=1.128275 time=37.1s - ttt_chunk [181/1893] bpb=1.130429 time=39.2s - ttt_chunk [191/1893] bpb=1.137679 time=41.4s - ttt_chunk [201/1893] bpb=1.136359 time=43.5s - ttt_chunk [211/1893] bpb=1.134080 time=45.6s - ttt_chunk [221/1893] bpb=1.135553 time=47.8s - ttt_chunk [231/1893] bpb=1.134289 time=50.0s - ttt_chunk [241/1893] bpb=1.134639 time=52.1s - ttt_chunk [251/1893] bpb=1.134101 time=54.3s - ttt_chunk [261/1893] bpb=1.131275 time=56.4s - ttt_chunk [271/1893] bpb=1.130166 time=58.6s - ttt_chunk [281/1893] bpb=1.131392 time=60.7s - ttt_chunk [291/1893] bpb=1.133204 time=63.0s - ttt_chunk [301/1893] bpb=1.133879 time=65.1s - ttt_chunk [311/1893] bpb=1.135889 time=67.3s - ttt_chunk [321/1893] bpb=1.137760 time=69.5s - ttt_chunk [331/1893] bpb=1.137552 time=71.6s - ttt_chunk [341/1893] bpb=1.136540 time=73.7s - ttt_chunk [351/1893] bpb=1.138836 time=75.9s - ttt_chunk [361/1893] bpb=1.138998 time=78.0s - ttt_chunk [371/1893] bpb=1.138241 time=80.2s - ttt_chunk [381/1893] bpb=1.138388 time=82.3s - ttt_chunk [391/1893] bpb=1.138226 time=84.5s - ttt_chunk [401/1893] bpb=1.136142 time=86.6s - ttt_chunk [411/1893] bpb=1.134984 time=88.8s - ttt_chunk [421/1893] bpb=1.134048 time=90.9s - ttt_chunk [431/1893] bpb=1.133893 time=93.0s - ttt_chunk [441/1893] bpb=1.134233 time=95.2s - ttt_chunk [451/1893] bpb=1.134538 time=97.3s - ttt_chunk [461/1893] bpb=1.133489 time=99.5s - ttt_chunk [471/1893] bpb=1.134113 time=101.7s - ttt_chunk [481/1893] bpb=1.133735 time=103.8s - ttt_chunk [491/1893] bpb=1.132623 time=105.9s - ttt_chunk [501/1893] bpb=1.132112 time=108.1s - ttt_chunk [511/1893] bpb=1.131422 time=110.2s - ttt_chunk [521/1893] bpb=1.129021 time=112.4s - ttt_chunk [531/1893] bpb=1.130191 time=114.5s - ttt_chunk [541/1893] bpb=1.130566 time=116.7s - ttt_chunk [551/1893] bpb=1.129469 time=118.8s - ttt_chunk [561/1893] bpb=1.129994 time=120.9s - ttt_chunk [571/1893] bpb=1.128956 time=123.1s - ttt_chunk [581/1893] bpb=1.128143 time=125.2s - ttt_chunk [591/1893] bpb=1.127465 time=127.4s - ttt_chunk [601/1893] bpb=1.127956 time=129.5s - ttt_chunk [611/1893] bpb=1.127873 time=131.7s - ttt_chunk [621/1893] bpb=1.127718 time=133.8s - ttt_chunk [631/1893] bpb=1.128404 time=135.9s - ttt_chunk [641/1893] bpb=1.128160 time=138.1s - ttt_chunk [651/1893] bpb=1.128290 time=140.3s - ttt_chunk [661/1893] bpb=1.127773 time=142.5s - ttt_chunk [671/1893] bpb=1.128124 time=144.7s - ttt_chunk [681/1893] bpb=1.128794 time=146.8s - ttt_chunk [691/1893] bpb=1.129783 time=149.0s - ttt_chunk [701/1893] bpb=1.129235 time=151.1s - ttt_chunk [711/1893] bpb=1.129209 time=153.3s - ttt_chunk [721/1893] bpb=1.128862 time=155.4s - ttt_chunk [731/1893] bpb=1.128910 time=157.6s - ttt_chunk [741/1893] bpb=1.129004 time=159.7s - ttt_chunk [751/1893] bpb=1.128835 time=161.9s - ttt_chunk [761/1893] bpb=1.128773 time=164.0s - ttt_chunk [771/1893] bpb=1.128469 time=166.2s - ttt_chunk [781/1893] bpb=1.129249 time=168.3s - ttt_chunk [791/1893] bpb=1.128848 time=170.5s - ttt_chunk [801/1893] bpb=1.129160 time=172.6s - ttt_chunk [811/1893] bpb=1.128928 time=174.8s - ttt_chunk [821/1893] bpb=1.128721 time=176.9s - ttt_chunk [831/1893] bpb=1.128543 time=179.1s - ttt_chunk [841/1893] bpb=1.127875 time=181.3s - ttt_chunk [851/1893] bpb=1.127634 time=183.4s - ttt_chunk [861/1893] bpb=1.127390 time=185.6s - ttt_chunk [871/1893] bpb=1.127670 time=187.8s - ttt_chunk [881/1893] bpb=1.127849 time=189.9s - ttt_chunk [891/1893] bpb=1.127424 time=192.1s - ttt_chunk [901/1893] bpb=1.127170 time=194.3s - ttt_chunk [911/1893] bpb=1.127283 time=196.4s - ttt_chunk [921/1893] bpb=1.127772 time=198.6s - ttt_chunk [931/1893] bpb=1.127756 time=200.7s - ttt_chunk [941/1893] bpb=1.127426 time=202.9s - ttt_chunk [951/1893] bpb=1.127838 time=205.0s - ttt_chunk [961/1893] bpb=1.127923 time=207.2s - ttt_chunk [971/1893] bpb=1.128766 time=209.3s - ttt_chunk [981/1893] bpb=1.128874 time=211.5s - ttt_chunk [991/1893] bpb=1.128901 time=213.6s - ttt_chunk [1001/1893] bpb=1.128884 time=215.8s - ttt_chunk [1011/1893] bpb=1.128691 time=217.9s - ttt_chunk [1021/1893] bpb=1.129045 time=220.0s - ttt_chunk [1031/1893] bpb=1.129522 time=222.2s - ttt_chunk [1041/1893] bpb=1.129158 time=224.3s - ttt_chunk [1051/1893] bpb=1.128931 time=226.5s - ttt_chunk [1061/1893] bpb=1.128972 time=228.6s - ttt_chunk [1071/1893] bpb=1.129594 time=230.8s - ttt_chunk [1081/1893] bpb=1.129874 time=232.9s - ttt_chunk [1091/1893] bpb=1.130609 time=235.0s - ttt_chunk [1101/1893] bpb=1.130636 time=237.2s - ttt_chunk [1111/1893] bpb=1.130479 time=239.3s - ttt_chunk [1121/1893] bpb=1.130261 time=241.5s - ttt_chunk [1131/1893] bpb=1.130129 time=243.6s - ttt_chunk [1141/1893] bpb=1.129824 time=245.8s - ttt_chunk [1151/1893] bpb=1.129831 time=247.9s - ttt_chunk [1161/1893] bpb=1.129436 time=250.1s - ttt_chunk [1171/1893] bpb=1.129754 time=252.2s - ttt_chunk [1181/1893] bpb=1.129011 time=254.4s - ttt_chunk [1191/1893] bpb=1.128899 time=256.5s - ttt_chunk [1201/1893] bpb=1.129320 time=258.7s - ttt_chunk [1211/1893] bpb=1.128833 time=260.8s - ttt_chunk [1221/1893] bpb=1.128546 time=263.0s - ttt_chunk [1231/1893] bpb=1.128252 time=265.1s - ttt_chunk [1241/1893] bpb=1.127907 time=267.3s - ttt_chunk [1251/1893] bpb=1.127307 time=269.4s - ttt_chunk [1261/1893] bpb=1.127269 time=271.6s - ttt_chunk [1271/1893] bpb=1.126887 time=273.7s - ttt_chunk [1281/1893] bpb=1.126666 time=275.9s - ttt_chunk [1291/1893] bpb=1.126438 time=278.0s - ttt_chunk [1301/1893] bpb=1.125829 time=280.1s - ttt_chunk [1311/1893] bpb=1.125429 time=282.3s - ttt_chunk [1321/1893] bpb=1.125084 time=284.4s - ttt_chunk [1331/1893] bpb=1.125020 time=286.6s - ttt_chunk [1341/1893] bpb=1.124877 time=288.7s - ttt_chunk [1351/1893] bpb=1.124806 time=290.9s - ttt_chunk [1361/1893] bpb=1.124874 time=293.0s - ttt_chunk [1371/1893] bpb=1.124751 time=295.2s - ttt_chunk [1381/1893] bpb=1.124744 time=297.3s - ttt_chunk [1391/1893] bpb=1.124342 time=299.5s - ttt_chunk [1401/1893] bpb=1.124300 time=301.6s - ttt_chunk [1411/1893] bpb=1.124427 time=303.8s - ttt_chunk [1421/1893] bpb=1.124680 time=305.9s - ttt_chunk [1431/1893] bpb=1.124395 time=308.1s - ttt_chunk [1441/1893] bpb=1.124890 time=310.2s - ttt_chunk [1451/1893] bpb=1.125220 time=312.4s - ttt_chunk [1461/1893] bpb=1.124762 time=314.5s - ttt_chunk [1471/1893] bpb=1.125818 time=316.7s - ttt_chunk [1481/1893] bpb=1.125373 time=318.8s - ttt_chunk [1491/1893] bpb=1.125197 time=321.0s - ttt_chunk [1501/1893] bpb=1.125110 time=323.1s - ttt_chunk [1511/1893] bpb=1.125122 time=325.3s - ttt_chunk [1521/1893] bpb=1.125139 time=327.4s - ttt_chunk [1531/1893] bpb=1.124629 time=329.5s - ttt_chunk [1541/1893] bpb=1.124480 time=331.7s - ttt_chunk [1551/1893] bpb=1.124810 time=333.9s - ttt_chunk [1561/1893] bpb=1.124806 time=336.0s - ttt_chunk [1571/1893] bpb=1.124639 time=338.2s - ttt_chunk [1581/1893] bpb=1.124756 time=340.3s - ttt_chunk [1591/1893] bpb=1.124598 time=342.5s - ttt_chunk [1601/1893] bpb=1.124770 time=344.6s - ttt_chunk [1611/1893] bpb=1.124702 time=346.8s - ttt_chunk [1621/1893] bpb=1.124304 time=348.9s - ttt_chunk [1631/1893] bpb=1.124615 time=351.1s - ttt_chunk [1641/1893] bpb=1.124636 time=353.2s - ttt_chunk [1651/1893] bpb=1.124589 time=355.4s - ttt_chunk [1661/1893] bpb=1.124465 time=357.5s - ttt_chunk [1671/1893] bpb=1.124947 time=359.7s - ttt_chunk [1681/1893] bpb=1.125087 time=361.8s - ttt_chunk [1691/1893] bpb=1.124929 time=364.0s - ttt_chunk [1701/1893] bpb=1.125091 time=366.2s - ttt_chunk [1711/1893] bpb=1.125101 time=368.4s - ttt_chunk [1721/1893] bpb=1.125107 time=370.5s - ttt_chunk [1731/1893] bpb=1.124976 time=372.6s - ttt_chunk [1741/1893] bpb=1.124756 time=374.8s - ttt_chunk [1751/1893] bpb=1.124590 time=376.9s - ttt_chunk [1761/1893] bpb=1.124735 time=379.1s - ttt_chunk [1771/1893] bpb=1.124633 time=381.3s - ttt_chunk [1781/1893] bpb=1.124660 time=383.4s - ttt_chunk [1791/1893] bpb=1.124244 time=385.6s - ttt_chunk [1801/1893] bpb=1.124120 time=387.8s - ttt_chunk [1811/1893] bpb=1.124015 time=389.9s - ttt_chunk [1821/1893] bpb=1.124071 time=392.1s - ttt_chunk [1831/1893] bpb=1.123469 time=394.3s - ttt_chunk [1841/1893] bpb=1.123389 time=396.5s - ttt_chunk [1851/1893] bpb=1.123171 time=398.6s - ttt_chunk [1861/1893] bpb=1.122807 time=400.8s - ttt_chunk [1871/1893] bpb=1.122787 time=402.9s - ttt_chunk [1881/1893] bpb=1.122334 time=405.1s - ttt_chunk [1891/1893] bpb=1.122105 time=407.2s - ttt_chunk [1893/1893] bpb=1.122151 time=407.5s -ttt_sliding:done val_loss=1.891102 val_bpb=1.120020 elapsed=407.5s -legal_ttt val_loss:1.8911 val_bpb:1.1200 eval_time:408049ms -legal_ttt_exact val_loss:1.89110239 val_bpb:1.12002032 diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/README.md b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/README.md deleted file mode 100644 index db99765e5d..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/README.md +++ /dev/null @@ -1,152 +0,0 @@ -# Record: 1.1570 BPB — 73.7M Ternary U-Net Transformer - -**BitNet b1.58 + 10L + NeoMuon + 4x relu² MLP + Factored Tied Embedding + Poly5 Softcap + YaRN 2048 + 8192 BPE + FP8 QAT + Base-3 LZMA + Stride-16 Sliding Eval** - -**val_bpb: 1.1570** (3-seed mean sliding, std 0.0007) | **15.99 MB** max artifact | 8×H100 SXM, 599s - -> **Full experiment log covering 250+ runs, ablations, and decision rationale, that could help anyone else: [RESULTS.md](RESULTS.md). Complete training logs in my personal repo: [logs/](https://github.com/CiprianFlorin-Ifrim/openai-parameter-golf-submission/tree/main/logs/cuda).** - -The results document linked here and in my repo showcases all methods and sweeps applied to both Binary and Ternary Bitnets, which unfortunately are incompatible with many methods, such as Tversky Layers, EMA, Muon WD, LM Logit Head ranking and many more. Scaling ratios and applicable/rejected techniques can be useful for other submissions too. - -## Results (3 seeds, 8×H100 SXM) - -| Seed | Steps | ms/step | Sliding BPB (s16) | val_bpb | RT bpb | Artifact | -|------|-------|---------|-------------------|---------|--------|----------| -| 42 | 6,530 | 91.7 | **1.1565** | 1.1816 | 1.1837 | 15,993,853 bytes | -| 1337 | 6,520 | 91.9 | 1.1568 | 1.1825 | 1.1839 | 15,995,705 bytes | -| 7 | 6,530 | 91.8 | 1.1578 | 1.1823 | 1.1850 | 15,992,753 bytes | -| **Mean** | **6,527** | **91.8** | **1.1570** | **1.1821** | **1.1842** | **15,994,104 bytes** | -| **Std** | **5** | **0.1** | **0.0007** | **0.0005** | **0.0007** | **1,498 bytes** | - -## Architecture - -- 10 transformer layers, dim=768, 8 heads, 4 KV heads (GQA), head_dim=96 -- BitNet b1.58 ternary quantisation: weights {-1, 0, +1}, ~1.6 bits/param, per-group (128) absmean scaling -- 4x MLP expansion (hidden=3072) with **relu²** activation, fused gate+up projection -- U-Net encoder/decoder with learned skip weights (ones-init) and per-block residual mix from input embedding -- Factored tied embedding: 8192×254 bottleneck with learned 254-to-768 and 768-to-254 projections -- Polynomial softcap (degree 5, cap=10) with Z-loss regularisation (1e-4) -- YaRN positional encoding (max_len=2048, ROPE_BASE=5000) -- Fused QKV projection (single TernaryLinear) -- FlashAttention-3 (Hopper native kernels) -- 73.7M parameters, 15.92MB artifact (64.9M ternary + 2.5M fp8 + 70KB code) - -## Key Techniques - -### Architecture -- **Width over depth:** 768d/10L outperforms 512d/25L — faster steps (91ms vs 127ms) yield 6,530 vs 4,720 steps in 600s -- **4x relu² MLP:** relu² is -0.024 bpb over relu at zero cost; 4x width adds -0.008 bpb over 3x at same step budget -- **EMBED_DIM=254:** frees ~4MB for wider MLP; 254 = 256-2 to fit code within the byte budget - -### Training -- **NeoMuon** with 3 Newton-Schulz steps: compensates for ternary STE gradient attenuation; 3 steps equivalent to 5 at convergence (+190 free steps) -- **Fused QKV + fused relu²:** ~4-6ms/step saving (~180 extra training steps) -- **FlashAttention-3:** -9% step time (~380 free steps) -- **524k batch tokens:** optimal for ternary STE — 262k too noisy, 1M loses gradient updates - -### Evaluation -- **Temperature scaling (T=0.90):** 5-point grid on training tokens; relu² logits slightly underconfident -- **Sliding window (stride=16):** full context per scored token, ~0.025 bpb over chunked eval - -### Compression -- **Base-3 + LZMA (preset=9):** 5 trits/byte packing, 39% reduction over int8+zlib; auto-compared against bitmask per run -- **FP8 QAT (e4m3):** halves fp_params (~5MB to ~2.5MB), only 0.002 bpb RT penalty -- **Shrinkage fix:** corrects ternary zero-fraction scale mismatch, eliminating all roundtrip gaps - -## Setup and Run - -```bash -# Environment setup (conda + Python 3.13 + PyTorch + FlashAttention-3 + Triton + dataset) -bash setup.sh - -# Activate and run -conda activate golf -SEED=42 bash run_cuda_ternary.sh -``` - -
-Full run command - -```bash -RUN_ID=ternary_run \ -DATA_PATH=./data/datasets/fineweb10B_sp8192 \ -TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model \ -ATTN_PROJ_TYPE=standard \ -LOGIT_HEAD_TYPE=standard \ -TVERSKY_MEMBERSHIP=sigmoid \ -TVERSKY_NUM_FEATURES=0 \ -TVERSKY_FEATURE_POOLS=0 \ -VOCAB_SIZE=8192 \ -BITNET_GROUP_SIZE=128 \ -BIGRAM_HASH=0 \ -EMBED_DIM=254 \ -TRAINING_DEPTH_RECURRENCE=0 \ -EVAL_DEPTH_RECURRENCE=0 \ -NUM_LAYERS=10 \ -MODEL_DIM=768 \ -NUM_KV_HEADS=4 \ -NUM_HEADS=8 \ -DIFF_ATTN=0 \ -MLP_MULT=4 \ -MLP_GROUPS=0 \ -MATRIX_OPTIMIZER=muon \ -ADAM_LR=0.05 \ -ADAM_WD=0.05 \ -MUON_BACKEND_STEPS=3 \ -MUON_MOMENTUM=0.95 \ -MUON_MOMENTUM_WARMUP_START=0.85 \ -MUON_MOMENTUM_WARMUP_STEPS=500 \ -MUON_WD=0.0 \ -MATRIX_LR=0.04 \ -SCALAR_LR=0.02 \ -TIED_EMBED_LR=0.02 \ -WARMDOWN_FRACTION=0.2 \ -LOGIT_SOFTCAP=10 \ -QK_GAIN_INIT=2.25 \ -ROPE_TYPE=yarn \ -YARN_MAX_LEN=2048 \ -ROPE_BASE=5000 \ -BATCH_TOKENS_START=0 \ -BATCH_SCHEDULE_FRACTION=0.33 \ -TRAIN_BATCH_TOKENS=524288 \ -SEQ_LEN_START=0 \ -SEQ_SCHEDULE_FRACTION=0.0 \ -TRAIN_SEQ_LEN=1024 \ -SMEAR=0 \ -ITERATIONS=10000 \ -WARMUP_STEPS=5 \ -MAX_WALLCLOCK_SECONDS=599 \ -VAL_LOSS_EVERY=0 \ -TRAIN_LOG_EVERY=1000 \ -CHURN_LOG_EVERY=0 \ -VAL_MAX_TOKENS=0 \ -TIE_EMBEDDINGS=1 \ -UNTIE_AT_FRACTION=0.00 \ -HEAD_LR=0.02 \ -CORR_WEIGHT_LR=0.02 \ -ACTIVATION=relu2 \ -SOFTCAP_TYPE=poly \ -MTP_HEADS=0 \ -REFINER=0 \ -REFINER_KERNEL=3 \ -SLIDING_EVAL=1 \ -SLIDING_EVAL_STRIDE=16 \ -SLIDING_BATCH_SIZE=256 \ -TEMP_SCALING=1 \ -FP_STORAGE=FP8 \ -SEED=42 \ -COMPILE_MODE=default \ -OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 train_gpt_cuda_ternary.py -``` - -
- -## Compliance - -- [x] 3 seeds run on 8×H100 SXM -- [x] All 3 seeds train in <=600s (max: 599.7s) -- [x] All 3 seeds artifact <=16,000,000 bytes (max: 15,995,705) -- [x] Sliding window eval stride=16, consistent (std=0.0007) -- [x] No test-time training on validation data -- [x] No network calls during evaluation -- [x] No external compute diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/RESULTS.md b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/RESULTS.md deleted file mode 100644 index 82fcd581f0..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/RESULTS.md +++ /dev/null @@ -1,1236 +0,0 @@ -# Parameter Golf — Complete Experiment Log - -**Author:** Ciprian-Florin Ifrim -**Date:** March 2026 - ---- - -## Challenge Overview - -Train the best language model that fits in a 16MB artifact and trains in under 10 minutes on 8×H100 SXM GPUs, evaluated by tokenizer-agnostic bits-per-byte (BPB) compression on the FineWeb validation set. - -- **Baseline:** 1.2244 bpb (9L 512d int8+zlib, 1k vocab) -- **Our best (ternary, valid):** 1.1565 bpb sliding (P2, 10L 768d relu² 4×MLP fp8, EMBED_DIM=254, seed=42, 16.00MB) -- **Our best (binary, unconstrained):** 1.1239 bpb sliding (15L 768d binary relu² 4×MLP fp8, 50k steps / ~2h compute, 15.67MB) -- **Our best (quality, over budget):** 1.1771 bpb (F59, 12L 768d swiglu 3×MLP, 21.96MB) -- **Challenge period:** March 18 – April 30, 2026 -- **Compute sponsor:** OpenAI ($1M in compute credits) - -The challenge is framed as L(N) optimisation — minimising loss given fixed parameter count N, unconstrained by data, compute, steps, or architecture. Related challenges include NanoGPT Speedrun (L(T): lowest loss given constrained time) and NanoGPT Slowrun (L(D): lowest loss given constrained dataset). - ---- - -## Run Numbering Convention - -| Prefix | Description | -|--------|-------------| -| Plain (1–100) | Dev runs on RTX 5090, 100 steps | -| R prefix (R1...) | Record runs — 600s on 8×H100, leaderboard-targeted | -| S prefix (S1...) | Scaling runs — 1500 steps or 300s on 8×H100, controlled sweeps | -| SB prefix (SB1...) | Binary scaling runs | -| F prefix (F1...) | Final runs — 600s on 8×H100, official submissions | -| P prefix (P1...) | Pushed/submission runs — final config pushed to GitHub | - -Additionally, 20 early architecture iterations were performed on MLX (Mac Studio M1 Ultra, 32GB unified memory) and 2 on MPS (MacBook Pro M1 Pro, 32GB unified memory) for rapid prototyping before GPU scaling. - -> **Note:** This document covers ~85 named runs (F, S, R series). An additional ~165 dev runs (plain numbered 1–100, repeated sweeps, smoke tests) were conducted but are not individually listed. Key findings from those runs are incorporated into the sweep tables and decision rationale. Separate synthetic-data notebooks were used to isolate the behaviour of specific techniques (Tversky similarity, linear alternatives, grouped projections) before committing H100 compute. - ---- - -## Hardware - -| System | Spec | Notes | -|--------|------|-------| -| Dev | RTX 5090 32GB, single GPU | Triton smem ceiling 101KB/SM; blocks value embeddings and some kernels | -| Mac (MLX) | Mac Studio M1 Ultra 32GB | MLX early iteration, 20 runs | -| Mac (MPS) | MacBook Pro M1 Pro 32GB | MPS early iteration, 2 runs | -| Final | 8×H100 SXM 80GB | Primary training platform | - -**Step times at 768d (12L):** relu² 2x: 89ms | relu² 3x: 99ms | relu² 4x: 91ms | swiglu 3x: 127ms | leaky relu 3x: 103ms - -**Step times at 512d:** 26L baseline: 149ms → 136ms with FA3 → 127ms with FA3 + fusions + EMBED=256 at 25L - -**FlashAttention-3** reduced step time by ~9% (~380 free training steps per 600s run). - -**Kernel fusion optimisations** (fused QKV + fused SwiGLU + dataloader + softcap) saved a further ~7-10ms/step. - -**Width vs depth discovery:** 12L 768d at 106ms/step gets ~5640 steps in 600s vs ~4720 steps for 25L 512d — 920 extra steps from the faster per-step time of wider/shallower models. Final 10L 768d 4×MLP at 91.8ms/step gets ~6530 steps. - ---- - -## Architecture: Ternary U-Net Transformer - -### Quantisation Scheme - -BitNet b1.58 ternary quantisation — weights constrained to {−1, 0, +1} with per-group absmean scaling. Approximately 1.6 bits per parameter. - -**Compression pipeline:** Base-3 packing (5 trits/byte) or bitmask packing → LZMA (preset=9). Best method auto-selected per run. Bitmask wins when zero fraction is high. - -**Quantisation shrinkage fix:** When ternary Q contains zeros, `mean(|Q|) < 1.0`, causing scale mismatch on reload. Fix: inflate by `1/mean(|Q|)` during dequantisation. Eliminates all roundtrip gaps. - -### U-Net Skip Connections - -The model uses a U-Net style encoder/decoder structure with learned skip connections. The first `num_layers // 2` blocks (encoder) store their outputs; the second half (decoder) receives these via `x = x + skip_weight[i] * skips.pop()`. This allows the decoder to simultaneously access high-level semantic representations (from deep processing) and low-level token-level features (from early processing), without requiring the decoder to reconstruct low-level information from the compressed residual stream. - -Additionally, each block receives `x0` (the original input embedding) via a learned residual mix: `x = mix[0] * x + mix[1] * x0`, giving every layer direct access to the raw token representation regardless of accumulated residual drift. - -For odd layer counts, the decoder receives the larger half (e.g. 27L → 13 encoder + 14 decoder), which is the standard U-Net convention — more processing power applied after skip injection. - -### Factored Embedding - -With `EMBED_DIM=254`, token embedding is `[8192, 254]` instead of `[8192, 768]`, with learned projections `embed_proj` (254→768) and `embed_proj_rev` (768→254) for the tied output head. - -**EMBED_DIM history:** Started at 128 (dev runs), upgraded to 256 after an optimizer coverage fix revealed that the projection layers had not been receiving gradients (−0.024 bpb improvement vs 128 once trained), then trimmed to 254 to fit artifact+code under the 16,000,000 byte budget (~0.0004 bpb cost, 0.00018/dim from 128→256 scaling data). - -### Fused Operations - -**Fused QKV:** Single `TernaryLinear(dim, dim + 2*kv_dim)`. **Fused SwiGLU/relu²:** Gate and up projections combined into single wide matrix. Combined saving: ~4-6ms/step. - -### Z-Loss Regularisation - -`1e-4 * logsumexp(logits)²` (from PaLM/Gemma) anchors logits near zero, keeping gradients sharp through the ternary STE. - ---- - -## Compression Scheme - -### Base-3 + LZMA (Primary) - -5 trits per byte (1.585 bits/trit), lossless. LZMA at preset=9 achieves ~39% reduction over int8+zlib. Ternary distribution at convergence: ~20–29% zeros, ~35–40% each ±1. The skewed distribution (more zeros) is exploited by LZMA's entropy coding. - -### Bitmask Compression (Alternative) - -Encodes "is this weight zero?" and "if nonzero, is it +1?" as separate bitmasks. Both methods are tried and the smaller is selected automatically. In practice, bitmask and base-3+LZMA produce nearly identical artifact sizes — bitmask wins marginally in some runs (e.g. S72: 15.84MB vs 15.87MB). Zero fraction would need to drop below ~5% for bitmask to provide a clear advantage; our zero fraction ranges from 17–29% at convergence, making bitmask non-competitive. - -### 3D Tensor Support - -Conv1d weights (`[dim, dim, kernel]`) are reshaped to 2D before ternary quantisation and restored to original shape on load. - -### FP8 QAT - -Non-ternary parameters (embeddings, projections) stored at fp8 (e4m3) with Quantisation-Aware Training via STE. Halves fp_params storage (~5MB → ~2.5MB). Typical roundtrip gap: 0.001–0.002 bpb. - ---- - -## Submission Runs (P prefix) — Ternary - -Configuration: F88 (10L 768d relu² 4×MLP fp8, WD=0, EMBED_DIM=254, 599s wallclock, TEMP=0.90) - -| Seed | Steps | val_bpb | RT bpb | Sliding bpb | Train Time | Eval Time | Artifact | Budget | -|------|-------|---------|--------|-------------|------------|-----------|----------|--------| -| 1337 | 6520 | 1.1825 | 1.1839 | **1.1568** | 599.1s | 428.7s | 15.92MB | 16.00/16.00MB | -| 42 | 6530 | 1.1816 | 1.1837 | **1.1565** | 599.7s | 429.3s | 15.92MB | 15.99/16.00MB | -| 7 | 6530 | 1.1823 | 1.1850 | **1.1578** | 599.6s | 429.0s | 15.92MB | 15.99/16.00MB | -| **Mean** | **6527** | **1.1821** | **1.1842** | **1.1570** | **599.5s** | **429.0s** | **15.92MB** | | -| **Std** | **5** | **0.0005** | **0.0007** | **0.0007** | **0.3s** | **0.3s** | **0.00MB** | | - -All three seeds fit within the 16,000,000 byte budget. The standard deviation of 0.0007 bpb across seeds confirms high reproducibility. All runs achieve p < 0.001 improvement over the 1.2244 bpb baseline. - -### Batch Size Sensitivity (Ternary, 599s wallclock) - -| Batch Tokens | Steps | ms/step | val_bpb | Sliding bpb | Tokens Seen | Fits Budget | -|-------------|-------|---------|---------|-------------|-------------|-------------| -| 262,144 | 10,000 | 49 | 1.2413 | — | 2.6B | No | -| **524,288** | **6,530** | **92** | **1.1850** | **1.1578** | **3.4B** | **Yes** | -| 1,048,576 | 3,480 | 172 | 1.1925 | 1.1659 | 3.5B | No | - -524k batch tokens is the optimal operating point. Halving the batch (262k) doubles the step count but degrades quality by 0.056 bpb due to noisier gradients interacting poorly with the ternary STE. Doubling it (1M) sees similar total tokens but fewer gradient updates, costing 0.008 bpb. - ---- - -## Current Best Configuration - -### Ternary: 10L 768d relu² 4×MLP fp8, WD=0, EMBED_DIM=254 - -```bash -NUM_LAYERS=10 MODEL_DIM=768 NUM_HEADS=8 -NUM_KV_HEADS=4 MLP_MULT=4 VOCAB_SIZE=8192 -ACTIVATION=relu2 LOGIT_SOFTCAP=10 SOFTCAP_TYPE=poly -QK_GAIN_INIT=2.25 ROPE_BASE=5000 ROPE_TYPE=yarn -YARN_MAX_LEN=2048 EMBED_DIM=254 TIE_EMBEDDINGS=1 -BITNET_GROUP_SIZE=128 FP_STORAGE=FP8 MUON_WD=0.0 -MATRIX_LR=0.04 SCALAR_LR=0.02 TIED_EMBED_LR=0.02 -MUON_BACKEND_STEPS=3 MUON_MOMENTUM=0.95 WARMDOWN_FRACTION=0.2 -MAX_WALLCLOCK_SECONDS=599 -SLIDING_EVAL=1 SLIDING_EVAL_STRIDE=16 TEMP_SCALING=1 -TRAIN_BATCH_TOKENS=524288 -``` - -| Metric | Value | -|--------|-------| -| val_bpb (mean) | 1.1821 | -| RT bpb (mean) | 1.1842 | -| Sliding bpb (mean) | 1.1570 | -| Artifact + code | 15,992,753–15,995,705 / 16,000,000 bytes | -| Steps | 6520–6530 | -| ms/step | 91.8 | -| zero_frac | 0.335–0.336 | -| optimal_T | 0.90 | -| Params | 73,685,840 | - ---- - -## Dev Runs (RTX 5090, 100–500 steps) - -### Phase 0 — Ternary vs Binary (500 steps, 16L 512d, 1k vocab) - -| Run | Config | val_bpb | RT bpb | Artifact | ms/step | -|-----|--------|---------|--------|----------|---------| -| 17 | Ternary baseline | 1.7110 | 1.7300 | 23.95MB | 1312 | -| 18 | Binary {−1,+1} | 1.7121 | 1.7316 | 23.93MB | 1309 | - -Ternary wins by 0.0016 bpb. The zero state provides representational benefit. - ---- - -### Phase 1 — Training Techniques (100 steps, 9L 512d, 1k vocab) - -| Run | Config | val_bpb | RT bpb | Artifact | Notes | -|-----|--------|---------|--------|----------|-------| -| 19 | Ternary 16L 512d baseline | 2.3371 | 2.3793 | 7.33MB | | -| 20 | + Untie lm_head at 2/3 | 2.3569 | 2.3983 | 8.13MB | Deferred — needs wallclock fix | -| 21 | + Value embeddings | — | — | — | Blocked: RTX 5090 Triton smem | -| 22 | + Smear module | 2.3593 | 2.3985 | 7.33MB | Deferred — gate needs many steps | -| 23 | Baseline 9L 512d | 2.4483 | 2.4768 | 4.45MB | Switched from 16L | -| 24 | + Polynomial softcap | 2.3981 | 2.4438 | 4.45MB | **−0.033 rt** | -| 25 | + Seq length schedule | 2.4633 | 2.5106 | 4.45MB | Deferred — recompile cost | -| 26 | + NorMuon | 2.4018 | 2.4104 | 4.40MB | **−0.033 rt**, 5× smaller RT gap | -| 27 | + Grad accum delay | 2.6298 | 2.6571 | 4.40MB | Deferred — needs 2000+ steps | - ---- - -### Vocabulary Sweep (100 steps, 9L 512d) - -| Run | Vocab | val_bpb | RT bpb | Artifact | Notes | -|-----|-------|---------|--------|----------|-------| -| 23 | 1024 | 2.4483 | 2.4768 | 4.45MB | Baseline | -| 28 | 4096 | 2.0930 | 2.0974 | 6.68MB | −0.32 vs 1k | -| **29** | **8192** | **1.9946** | **1.9990** | **9.64MB** | **−0.42 vs 1k — largest single win** | - -8192 vocab locked. The tokeniser merges ~1.57× more aggressively than 1k, directly reducing BPB. Val token count drops from 63.8M (sp1024) to 40.5M (sp8192) for the same 50k documents. - ---- - -### Activation Sweep (100 steps, 9L 512d, 8k vocab) - -| Run | Activation | val_bpb | RT bpb | Artifact | ms/step | -|-----|-----------|---------|--------|----------|---------| -| 29 | relu2 | 1.9946 | 1.9990 | 9.64MB | 838 | -| 30 | relu | 1.9846 | 1.9879 | 9.63MB | 830 | -| **31** | **SwiGLU** | **1.9704** | **1.9743** | **10.70MB** | **960** | -| 32 | SwiGLU + MTP(2) | 1.9627 | 1.9672 | 10.69MB | 1111 | - -SwiGLU with MTP auxiliary loss gives −0.032 bpb but +16% slower. SwiGLU alone gives −0.025 bpb. MTP deferred. - ---- - -### Embedding Factorization Sweep (100 steps, 9L 512d, 8k vocab) - -| Run | EMBED_DIM | val_bpb | RT bpb | RT gap | Artifact | -|-----|-----------|---------|--------|--------|----------| -| 33a | 0 (=512) | 1.9931 | 1.9962 | 0.003 | 9.63MB | -| **33d** | **128** | **1.9656** | **1.9656** | **0.000** | **9.12MB** | -| 33c | 256 | 2.0538 | 2.1339 | 0.080 | 6.68MB | -| 33e | 64 | 2.0936 | 2.0968 | 0.003 | 4.49MB | -| 33f | 1024 | 2.0709 | 2.1845 | 0.114 | 15.60MB | - -128 was optimal at dev scale. After an optimizer fix revealed the projection layers had not been training, 256 became optimal at full convergence — see EMBED_DIM Sweep at full convergence. - ---- - -### Tversky Neural Network Investigation - -Based on Doumbouya et al. (2025). Three-term Tversky similarity: `S = theta * f(A intersection B) - alpha * f(A - B) - beta * f(B - A)` with learned membership functions. - -**Feature count sweep (FP16 features, ternary prototypes, 100 steps, 9L 512d):** - -| Run | Features | val_bpb | RT bpb | RT gap | Artifact | -|-----|----------|---------|--------|--------|----------| -| — | No Tversky | 1.9751 | 1.9751 | 0.000 | 5.33MB | -| 38 | 16 | 1.9877 | 2.0186 | 0.031 | 5.46MB | -| 39 | 32 | 1.9843 | 2.0133 | 0.029 | 5.57MB | -| 40 | 64 | 1.9790 | 2.0097 | 0.031 | 5.79MB | -| **41** | **128** | **1.9427** | **1.9865** | **0.044** | **6.20MB** | -| 42 | 256 | 1.9737 | 2.0863 | 0.113 | 5.63MB | -| 43 | 512 | 2.0036 | 2.0965 | 0.093 | 5.90MB | -| 44 | 128 + shrinkage fix | 1.9425 | **1.9425** | **0.000** | 6.20MB | - -Tversky showed genuine quality benefit (~-0.017 bpb) at dev scale with 128 features and fp16 prototype storage. However, subsequent investigation at full convergence (12L 768d) and with corrected prototype storage showed all Tversky variants within noise of the linear baseline. Additional experiments included full ternary prototypes, shared feature pools across layers, no-features mode, logit-head application, and different membership functions (sigmoid, poly, tanh). A synthetic-data notebook confirmed that Tversky's asymmetric similarity only helps on tasks with genuine directional feature relationships (hypernym/hyponym, cause/effect); next-token prediction on FineWeb web text is not such a task. - -At the 768d architecture with relu², Tversky also incurred a 19ms/step overhead because the smaller MLP no longer masked the compute cost. - -**Conclusion:** Tversky is quality-neutral on FineWeb language modelling regardless of configuration. Not a quantisation issue, not an optimizer issue — the task simply does not benefit from asymmetric similarity. - ---- - -### Key Hyperparameter Sweeps (100 steps, 9L 512d, 8k vocab) - -**QK_GAIN_INIT sweep:** - -| Run | QK_GAIN | val_bpb | Delta | -|-----|---------|---------|-------| -| 75 | 1.0 | 2.0007 | +0.0076 | -| 73 | 1.5 | 1.9931 | baseline | -| 81 | 2.15 | 1.9913 | −0.0018 | -| **79** | **2.25** | **1.9898** | **−0.0033** | -| 77 | 2.5 | 1.9915 | −0.0016 | -| 80 | 2.75 | 1.9975 | +0.0044 | -| 78 | 3.0 | 2.0011 | +0.0080 | - -Clear inverted-U response. **QK_GAIN_INIT=2.25 locked.** - -**LOGIT_SOFTCAP sweep:** - -| Run | SOFTCAP | val_bpb | Delta | -|-----|---------|---------|-------| -| 74 | 5 | 1.9942 | −0.0013 | -| **73** | **10** | **1.9931** | **−0.0024** | -| 72 | 20 | 1.9935 | −0.0020 | -| 71 | 50 | 1.9957 | +0.0003 | - -**LOGIT_SOFTCAP=10 locked.** - -**Softcap type (poly vs tanh):** - -| Run | Type | val_bpb | Notes | -|-----|------|---------|-------| -| S23 | poly | 1.3680 | | -| S24 | tanh | 1.3693 | | -| S28/S29 | both at EMBED=1024 | 1.3460–1.3462 | Identical at convergence | - -Zero effect. Polynomial retained as default. - -**ROPE_BASE sweep:** - -| Run | ROPE_BASE | val_bpb | Notes | -|-----|-----------|---------|-------| -| **70** | **5000** | **1.9959** | Best at short training | -| 73 | 10000 | 1.9931 | Close second | -| 69 | 20000 | 2.0008 | | -| 68 | 50000 | 2.0017 | | - -**KV Heads:** - -| Run | KV_HEADS | val_bpb | Artifact | -|-----|----------|---------|----------| -| **58** | **4 (GQA)** | **1.9955** | **7.75MB** | -| 66 | 8 (MHA) | 2.0148 | 8.46MB | - -**MLP_MULT:** - -| Run | MLP_MULT | val_bpb | Artifact | -|-----|----------|---------|----------| -| **58** | **2** | **1.9955** | **7.75MB** | -| 64 | 3 | 2.0004 | 9.09MB | -| 65 | 4 | 1.9992 | 10.39MB | - -**Storage precision:** - -| Run | Storage | val_bpb | RT bpb | RT gap | Artifact | -|-----|---------|---------|--------|--------|----------| -| **90** | **fp16** | **1.9656** | **1.9656** | **0.000** | **9.06MB** | -| 91 | fp8 | 1.9662 | 1.9702 | 0.004 | 7.83MB | -| 92 | fp4 | 1.9661 | 1.9955 | 0.029 | 7.11MB | - -**TTT-LoRA sweep (100 steps, ROPE=5000):** - -| Run | Rank | LR | TTT bpb | Delta | -|-----|------|-----|---------|-------| -| **85** | **8** | **0.01** | **1.9368** | **−0.0315** | -| 86 | 8 | 0.005 | 1.9378 | −0.0312 | -| 87 | 8 | 0.02 | 1.9644 | −0.0038 | -| **88** | **4** | **0.01** | **1.9371** | **−0.0285** | -| 89 | 16 | 0.01 | OOM | — | - -TTT confirmed working at dev scale (−0.0315 bpb). Incompatible at convergence — see TTT investigation. - -**EMBED_DIM sweep at 512d (12L, 100 steps):** - -| Run | EMBED_DIM | Tversky feat | RT bpb | Artifact | bpb/MB efficiency | -|-----|-----------|-------------|--------|----------|-------------------| -| 95 | 64 | 128 | 2.1961 | 8.40MB | worst | -| 98 | 96 | 128 | 2.0356 | 8.74MB | | -| 97 | 128 | 128 | 1.9656 | 9.12MB | best | -| 99 | 192 | 128 | 2.0409 | 10.07MB | | -| 94 | 256 | 128 | 2.0703 | 10.93MB | | -| 100 | 256 | 256 | 2.0340 | 10.09MB | RT gap 0.021 | -| 96 | 512 (off) | 128 | 2.0642 | 13.50MB | | - -128 confirmed optimal at dev scale. - ---- - -### Architecture Sizing Table (Ternary, EMBED_DIM=128, standard proj) - -| Config | Layers | Artifact | Under 16MB? | RT gap | Headroom | -|--------|--------|----------|-------------|--------|----------| -| fp16 | 20 | 14.23MB | Yes | 0.0001 | 1.77MB | -| **fp16** | **22** | **15.48MB** | **Yes** | **0.0001** | **0.52MB** | -| fp16 | 24 | 16.74MB | No | — | −0.74MB | -| fp8 QAT | 24 | 14.63MB | Yes | 0.028 | 1.37MB | -| fp8 QAT | 26 | 15.77MB | Yes | 0.066 | 0.23MB | -| **fp8 QAT** | **27** | **15.42MB** | **Yes** | **0.0025** | **0.58MB** | -| fp8 QAT | 28 | 15.92MB+code | Marginal | 0.0029 | ~0MB | -| fp8 QAT | 30 | 16.92MB | No | 0.0029 | −0.92MB | - ---- - -## H100 Record Runs (R prefix) - -**Hardware:** 8×H100 SXM 80GB | **Time limit:** 600 seconds - -| Run | Config | Steps | val_bpb | RT bpb | Artifact | Notes | -|-----|--------|-------|---------|--------|---------|-------| -| R1 | 22L Tversky fp16 | 4299 | 1.2789 | 1.2792 | 15.80MB | | -| R2 | 26L standard fp16 | 3973 | 1.2649 | 1.2650 | 15.85MB | Pre-LR tuning best | -| R3 | 16L Tversky fp16 | 5949 | 1.2900 | 1.2904 | 11.95MB | Too shallow | -| R4 | 9L Tversky fp16 | 10112 | 1.3374 | 1.3394 | 7.48MB | Way too shallow | -| R5 | 30L fp8 | 2852 | 1.2689 | 1.2815 | 17.22MB | Over budget | -| R6 | 26L fp16, 2× LR | ~4003 | 1.2991 | — | ~15.85MB | LR overshot | -| **R7** | **26L fp16, LR=0.02** | **4008** | **1.2608** | **1.2610** | **15.83MB** | **Best pre-FA3** | -| R8 | 26L fp16, LR=0.01 | 4017 | 1.2853 | 1.2855 | 15.72MB | LR too low | -| R9 | 26L BigramHash | 4010 | 1.2804 | 1.2802 | 15.81MB | BigramHash negative | -| R10 | 26L untie@66% | 3706 | 1.2754 | 1.2753 | 23.15MB | Over budget | -| R11 | 26L tied, updated code | 4009 | 1.2806 | 1.2808 | 15.81MB | Code regression | - -**LR sweep (R-series):** - -| LR | val_bpb | Notes | -|----|---------|-------| -| 0.08 | 1.2991 | Overshoots — ternary STE amplifies gradient noise | -| **0.02** | **1.2608** | **Optimal** | -| 0.01 | 1.2853 | Too slow | - ---- - -## Scaling Runs (S prefix) - -**Hardware:** 8×H100 SXM 80GB | **Steps:** 1500 | **Timer:** disabled (MAX_WALLCLOCK_SECONDS=0) -**Base config:** 26L 512d, EMBED_DIM=128, ROPE=5000, QK_GAIN=2.25, SOFTCAP=10, LR=0.02 all, VOCAB=8192, SwiGLU, SEED=1337 - ---- - -### Warmdown Sweep - -| Run | Fraction | val_bpb | -|-----|----------|---------| -| S3 | 10% | 1.3467 | -| **S1** | **20%** | **1.3438** | -| S2 | 30% | 1.3443 | -| S4 | 30% repeat | 1.3458 | -| S5 | 40% | 1.3501 | - -S2 vs S4 (identical config): 0.0015 bpb spread — confirmed seed variance floor. - -### Muon Backend Steps - -| Run | Steps | ms/step | val_bpb | -|-----|-------|---------|---------| -| S8 | 3 | 144.87 | 1.3491 | -| S9 | 4 | 146.61 | 1.3448 | -| **S1** | **5** | **149.19** | **1.3438** | -| S7 | 8 | 164.31 | 1.3441 | -| S6 | 10 | 157.95 | 1.3456 | - -At full convergence (F6 vs F1): 3 steps matches 5 due to +190 extra training steps. Locked at 3. - -### Muon Momentum - -| Run | Momentum | val_bpb | zero_frac | Artifact | -|-----|----------|---------|-----------|---------| -| S11 | 0.90 | 1.3680 | 0.179 | 15.39MB | -| **S1** | **0.95** | **1.3438** | **0.205** | **15.56MB** | -| S10 | 0.99 | 1.3505 | 0.259 | 15.78MB | - -Higher momentum increases zero_frac, inflating artifact size. - -### Architecture Experiments - -| Run | Config | ms/step | val_bpb | Notes | -|-----|--------|---------|---------|-------| -| S12 | 20L 640d (80M params) | 160.58 | 1.6676 | 17.75MB — over budget | -| **S1** | **26L 512d baseline** | **149.19** | **1.3438** | **Reference** | -| S13 | 26L, TRAINING_DR=2 | 281.63 | 1.3727 | ~795 effective steps, OOM at DR=3 | - -### Eval Depth Recurrence Sweep - -| Run | EVAL_DR | val_bpb | -|-----|---------|---------| -| S15 | 0/1 | 1.3685–1.3690 | -| S16 | 2 | 1.3688 | -| S17 | 3 | 1.3681 | -| S18 | 4 | 1.3690 | -| S19 | 5 | 1.3683 | - -Total range: 0.0009 bpb — pure noise. - -### Weight Decay (1500 steps) - -| Run | MUON_WD | val_bpb | zero_frac | Artifact | -|-----|---------|---------|-----------|---------| -| **S15** | **0.00** | **1.3685** | **0.179** | **15.39MB** | -| S20 | 0.04 | 1.3722 | 0.145 | 15.12MB | - -WD hurts at 1500 steps but saves 0.27MB. Reversed at full convergence — see Final Ternary Record Runs. - -### BigramHash - -| Run | Config | Steps | val_bpb | Artifact | -|-----|--------|-------|---------|---------| -| S21 | 26L + BigramHash | 1500 | 1.3681 | 15.45MB | -| R9 | 26L + BigramHash | 4010 | 1.2804 | 15.81MB | - -At full convergence: 0.020 bpb worse than R7. The 2.1MB fp16 cost of the bigram table displaces ternary layer depth at convergence. **Not viable within budget.** - -### Tied Embedding / Correction Weight / Untie Investigation - -| TIE_EMBEDDINGS | UNTIE_AT_FRACTION | LM_HEAD_RANK | Behaviour | -|---------------|-------------------|--------------|-----------| -| 0 | any | any | Untied from start — unstable, loss = log(8192) = 9.01 | -| 1 | 0.0 | 0 | Always tied — current best | -| 1 | 0.66 | 0 | Tied → full-rank untie at 66% of wallclock | -| 2 | 0.0 | 0 | Tied + correction weight residual on tok_emb | -| 2 | 0.66 | 0 | Tied + correction → full-rank untie at 66% | -| 2 | 0.66 | r | Tied + correction → SVD rank-r untie at 66% | - -**1500-step results:** - -| Run | TIE | UNTIE | RANK | val_bpb | Artifact | -|-----|-----|-------|------|---------|---------| -| S15 | 1 | 0.00 | 0 | 1.3685 | 15.39MB | -| S30 | 2 | 0.00 | 0 | 1.3678 | 15.39MB | -| S36 | 1 | 0.66 | 0 | 1.3648 | 22.83MB | -| **S37** | **2** | **0.66** | **0** | **1.3642** | **22.84MB** | -| S38 | 1 | 0.66 | 0 | 1.3667 | 22.84MB | -| S39 | 0 | 0.66 | 0 | 3.4890 | 10.88MB | - -Untie gives +0.005 bpb gain but adds 7.3MB — over budget. **TIE=1, no untie locked.** - -### LM Head Factorization (SVD-at-Untie) - -| Run | RANK | val_bpb | Artifact | Delta vs baseline | -|-----|------|---------|---------|-------------------| -| S37 | 0 (full) | 1.3642 | 22.84MB | +0.004 — over budget | -| S43 | 32 | 1.4873 | 17.27MB | −0.119 | -| S41 | 64 | 1.4243 | 17.60MB | −0.056 | -| S42 | 128 | 1.3889 | 18.40MB | −0.020 | - -SVD factorization does not recover within the remaining 34% of training. The model requires full-rank lm_head for 8192-class separability in 512-dimensional space. - -### Tied Embed LR Sweep - -| Run | TIED_EMBED_LR | MATRIX_LR | SCALAR_LR | val_bpb | -|-----|--------------|-----------|-----------|---------| -| S33 | 0.01 | 0.02 | 0.02 | 1.3723 | -| **S15** | **0.02** | **0.02** | **0.02** | **1.3685** | -| S34 | 0.03 | 0.02 | 0.02 | 1.3742 | - -Symmetric degradation. **TIED_EMBED_LR=0.02 locked.** - -### TTT-LoRA Investigation - -Test-time training with per-document LoRA adapters. Confirmed working at dev scale (−0.0315 bpb). Incompatible at convergence across 6 diagnostic runs. - -| Run | Config | val_bpb | TTT bpb | Notes | -|-----|--------|---------|---------|-------| -| S22 | TTT_LR=0.01 | 1.3690 | 1.5065 | TTT hurts | -| S23 | No lm_head_lora | 1.3690 | 1.4993 | Still hurts | -| S24 | tanh softcap | 1.3693 | 1.4982 | No improvement | -| S25 | Q/V loras only | 1.3692 | 1.5193 | Worse | -| S26 | EMBED_DIM=1024 | 1.3473 | 1.4746 | Bottleneck not cause | -| S27 | 9L (original depth) | 1.4039 | 1.5189 | Still incompatible at 9L | - -**Root cause:** Every `TernaryLinear` applies RMSNorm to its input before the weight multiply. The LoRA adapter delta is computed on the pre-normalised representation, but injected into a forward pass where base weights operate on a differently-normalised space. At 100 steps the model is poorly calibrated and LoRA signal dominates. At convergence, the base model's representations are precisely calibrated to this normalised space, and any LoRA delta corrupts rather than adapts. This incompatibility is architectural. **TTT permanently disabled.** - -### MTP (Multi-Token Prediction) - -| Run | MTP_HEADS | ms/step | val_bpb | Notes | -|-----|-----------|---------|---------|-------| -| **S47** | **0** | **149** | **1.3693** | **Baseline** | -| S45 | 2 | 157 | 1.3704 | +0.0011 worse | -| S62 | 2 | 144 | 1.3727 | +0.0034 worse | - -Confirmed at both 1500 steps and full convergence (post-fix retest: 0.006 bpb worse at both MTP=1 and MTP=2). A 60M+ parameter, 1.58-bit model does not have the parameter bandwidth for auxiliary future-planning objectives. - -### Smear Module - -| Run | SMEAR | val_bpb | ms/step | -|-----|-------|---------|---------| -| **S48** | **0** | **1.3687** | **149** | -| S49 | 1 | 1.3675 | 182 | - -+22% slower, −0.0012 bpb at 1500 steps. At full 600s wallclock, smear costs ~740 fewer training steps. Not viable within the ternary 10-minute budget but explored further in the binary track. - -### Sequence Length Schedule - -| Run | Config | val_bpb | ms/step avg | -|-----|--------|---------|-------------| -| S48 | baseline | 1.3687 | 149 | -| S51 | smear + seq@33% | 1.3660 | ~240 | -| S52 | smear + seq@33% repeat | 1.3640 | ~221 | -| **S58** | **smear + seq@33% + YaRN** | **1.3628** | **~221** | - -Real gain at 1500 steps but severe step penalty at full 600s. **Disabled for final runs.** - -### Batch Size Schedule - -| Run | Config | val_bpb | -|-----|--------|---------| -| S48 | baseline | 1.3687 | -| S50 | smear + batch | 1.3698 | -| S53 | smear + seq + batch | 1.3667 | - -Noisier gradients interfere with ternary STE convergence. **Not viable.** - -### YaRN Positional Encoding - -| Run | Config | val_bpb | -|-----|--------|---------| -| S48 | RoPE baseline | 1.3687 | -| S54 | YaRN 4096 | 1.3705 | -| S55 | YaRN 2048 | 1.3679 | -| S56 | YaRN 2048 + seq@33% | 1.3672 | -| S57 | YaRN 2048 + seq@50% + smear | 1.3637 | -| **S58** | **YaRN 2048 + seq@33% + smear** | **1.3628** | - -YaRN 4096 hurts (scale=0.25 too aggressive). YaRN 2048 marginally better. **YaRN 2048 retained; seq schedule disabled.** - -ROPE_BASE with YaRN: S63 (10000) = 1.3692, **S61 (5000) = 1.3686**. ROPE_BASE=5000 locked. - -### Sliding Window Evaluation - -| Run | Stride | Sliding bpb | Eval time | -|-----|--------|-------------|-----------| -| S60 | 16 | 1.3452* | >600s | -| S67 | 24 | 1.3146 | 592s | -| **S61/S66** | **32** | **1.3139–1.3452*** | **~350s** | - -*S60/S61 used incorrect momentum=0.90. At full convergence (F1): stride=32 gives 1.2312 sliding bpb in 280s. - -### Temperature Scaling - -Grid search over T in [0.80, 1.20] on 65,536 training tokens. 5-point grid. Optimal T was consistently 1.00 at convergence for the 512d SwiGLU architecture. At the 768d relu² architecture, T=0.90 was consistently optimal (relu² logits slightly underconfident). **TEMP_SCALING=1 in all final runs.** - -### Group Size Sweep (S73–S76, 2000 steps, 27L) - -| Run | Group Size | Layers | val_bpb | Artifact | Total | -|-----|-----------|--------|---------|----------|-------| -| S76 | 32 | 27 | 1.2739 | 17.64MB | 17.73MB | -| S75 | 64 | 27 | 1.2683 | 16.22MB | 16.31MB | -| **S73** | **128** | **27** | **1.2677** | **15.53MB** | **15.62MB** | -| S74 | 256 | 27 | 1.2699 | 15.19MB | 15.28MB | - -128 wins on both quality and compression. - -### Skip Weights Init — Zero vs Ones (S77) - -| Run | Init | val_bpb | artifact | -|-----|------|---------|---------| -| S73 | ones | 1.2677 | 15.62MB | -| S77 | zeros | 1.2781 | 15.62MB | - -Zero-init is **0.0104 bpb worse**. Decoder needs skip signal from step 0. - -### FP8/FP4 Storage with QAT - -**FP8 sweep:** - -| Run | Config | val_bpb | RT bpb | RT gap | Sliding bpb | Artifact | -|-----|--------|---------|--------|--------|-------------|---------| -| S64 | 26L fp16 | 1.3390 | 1.3390 | 0.000 | 1.3150 | 15.58MB | -| S65 | 30L fp8, no QAT | 1.3346 | 1.3394 | 0.0048 | 1.3150 | 16.92MB | -| S66 | 30L fp8, QAT | 1.3351 | 1.3380 | 0.0029 | **1.3139** | 16.92MB | -| S71 | 27L fp8, QAT | 1.3380 | 1.3405 | 0.0025 | 1.3164 | 15.42MB | -| S72 | 28L fp8, QAT | 1.3377 | 1.3406 | 0.0029 | 1.3166 | 15.92MB | - -QAT reduces fp8 RT gap from 0.0048 to 0.0029 (40% improvement). However at full convergence (F3), 28L fp8 QAT (1.2353 sliding) loses to 26L fp16 (1.2312 sliding). - -**FP4 sweep:** - -| Run | Config | val_bpb | RT bpb | RT gap | Sliding bpb | Artifact | -|-----|--------|---------|--------|--------|-------------|---------| -| S68 | 30L fp4 QAT | 1.3377 | 1.3643 | **0.0266** | 1.3404 | 16.49MB | -| S69 | 26L fp4 Tversky QAT | 1.3543 | 1.3835 | **0.0292** | 1.3606 | 15.01MB | -| S70 | 28L fp4 QAT | 1.3405 | 1.3666 | **0.0261** | 1.3424 | 15.43MB | - -FP4 RT gap of ~0.026–0.029 even with QAT is unrecoverable. **FP4 not viable at any layer count.** - -### EMBED_DIM Sweep (Full Convergence, 25L) - -| Config | EMBED_DIM | Steps | val_bpb | sliding_bpb | artifact | Notes | -|--------|-----------|-------|---------|-------------|---------|-------| -| S80 | 0 (=512) | 4500 | 1.1902 | ~1.168 est | 19.78MB | OOM on sliding eval | -| **F22** | **256** | **4720** | **1.2012** | **1.1739 (s16)** | **16.21MB** | **Best 512d result** | -| F16-era | 128 | 4310 | 1.2245 | — | 16.19MB | Pre-fix baseline | - -**EMBED_DIM=256 locked.** Budget impact: fp_params ~4.85MB vs ~2.48MB at 128 (+2.37MB). - ---- - -## Final Ternary Record Runs (F prefix) - -**Hardware:** 8×H100 SXM 80GB | **FlashAttention-3 enabled** | **Time limit:** 600 seconds - -| Run | Config | Steps | val_bpb | RT bpb | Sliding bpb | Eval time | Artifact | -|-----|--------|-------|---------|--------|-------------|-----------|---------| -| **F1** | **26L fp16, no smear, no seq** | **4362** | **1.2560** | **1.2560** | **1.2312** | **280s** | **15.85MB** | -| F2 | 26L fp16, smear + seq@33% | 3044 | 1.2779 | 1.2778 | 1.2535 | 390s | 15.85MB | -| F3 | 28L fp8 QAT, no smear, no seq | 4019 | 1.2571 | 1.2601 | 1.2353 (s24) | 385s | 16.14MB | -| F4 | 26L fp16, EMA=1 | 4145 | 1.2589 | 2.3307 | — | — | 14.52MB | -| F5 | 26L fp16, EMA fix v1 (smoke) | 407 | 1.5483 | 2.3642 | — | — | 14.90MB | -| F6 | 26L fp16, MUON_BACKEND_STEPS=3 | 4552 | 1.2558 | 1.2558 | 1.2311 (s24) | 362s | 15.81MB | -| F7 | 26L fp16, WD=0.04, steps=3 | 4499 | 1.2552 | 1.2551 | 1.2302 (s24) | 362s | 15.60MB | -| F8 | 28L fp16, WD=0.04, steps=2, LR=0.02 | 4219 | 1.2799 | 1.2801 | 1.2558 (s16) | 577s | 15.92MB | -| F9 | 28L fp16, WD=0.04, steps=2, LR=0.03 | 4231 | 1.2673 | 1.2676 | 1.2431 (s16) | 577s | 16.00MB | -| F10 | 28L fp16, WD=0.04, steps=2, LR=0.04 | 4226 | 1.2636 | 1.2636 | 1.2391 (s16) | 578s | 16.01MB | -| F11 | 28L fp16, WD=0.04, steps=3, LR=0.04 | 4137 | 1.2489 | 1.2488 | — | — | 16.69MB | -| F12 | 28L fp16, WD=0.04, steps=4, LR=0.04 | 4047 | 1.2496 | 1.2500 | — | — | 16.71MB | -| F13 | 28L fp16, WD=0.04, steps=3, LR=0.05 | 4048 | 1.2512 | 1.2510 | — | — | 16.73MB | -| F14 | 28L fp16, WD=0.04, steps=3, LR=0.08 | 4036 | 1.2576 | 1.2574 | — | — | 16.75MB | -| F15 | 27L fp16, AdamW matrix, LR=0.01 | 4676 | 1.2943 | 1.2942 | — | — | 15.71MB | -| F16 | 27L fp16, Muon, LR=0.04, WD=0.04 | 4310 | 1.2245 | — | — | — | 16.19MB | -| **F22** | **25L fp16, EMBED=256, steps=3, WD=0.04** | **4720** | **1.2012** | **1.2011** | **1.1739 (s16)** | **493s** | **16.21MB** | - -**Key findings:** F22 with EMBED_DIM=256 and corrected optimizer achieves 0.055 bpb improvement over F1 (the best pre-fix config). 28L extensively attempted (F8–F14) but artifact always over budget at competitive LR. AdamW for matrix params (F15) is clearly worse than Muon. - ---- - -## Phase 2 — Post-Optimizer-Fix Experiments (25L 512d EMBED=256) - -### EMA (Exponential Moving Average) - -| Run | Config | Steps | val_bpb | RT bpb | Artifact | -|-----|--------|-------|---------|--------|----------| -| F4 | EMA=1, decay=0.999 | 4145 | 1.2589 | 2.3307 | 14.52MB | -| — | Full run with EMA | 4144 | 1.2584 | 1.3776 | 14.94MB | - -**EMA is fundamentally incompatible with ternary quantization.** EMA averaging in fp32 produces smoother, more zero-centered weights. More latent weights near zero → more round to 0 in ternary → scale factor mismatch → 0.13 bpb RT gap. **Permanently disabled.** - -### Muon Backend Steps — Full Convergence - -| Run | Steps | step_avg | val_bpb | sliding_bpb | artifact | -|-----|-------|----------|---------|-------------|---------| -| F1 (steps=5) | 4362 | 137ms | 1.2560 | 1.2312 | 15.85MB | -| F6 (steps=3) | 4552 | 131ms | 1.2558 | 1.2311 | 15.81MB | - -6ms/step saving → 190 extra steps → quality equivalent. **MUON_BACKEND_STEPS=3 locked.** - -### Weight Decay — Full Convergence - -| Run | WD | Steps | val_bpb | sliding_bpb | zero_frac | artifact | -|-----|-----|-------|---------|-------------|-----------|---------| -| F6 | 0.00 | 4552 | 1.2558 | 1.2311 | 0.294 | 15.81MB | -| F7 | 0.04 | 4499 | 1.2552 | 1.2302 | 0.221 | 15.60MB | - -WD=0.04 wins at full convergence on the 26L architecture. However at 10L 4×MLP (Phase 4), WD=0.00 was better — wider MLP needs full weight freedom. - -### MTP Retest (Post-Fix) - -| Run | MTP_HEADS | Steps | step_avg | val_bpb | artifact | -|-----|-----------|-------|----------|---------|---------| -| F22 baseline | 0 | 4720 | 127ms | 1.2012 | 16.29MB | -| Run 26 | 1 | 4560 | 131ms | 1.2074 | 16.30MB | -| Run 27 | 2 | 4420 | 135ms | 1.2074 | 16.29MB | - -**MTP confirmed not viable post-fix.** 0.006 bpb worse at both heads. **MTP_HEADS=0 permanently locked.** - -### Tversky Phase 2 (Post-Fix, 12L 768d, fp16 Prototypes) - -Comprehensive retest with corrected optimizer and fp16 prototype storage: - -| Run | Config | Features | Pools | val_bpb | RT gap | -|-----|--------|----------|-------|---------|--------| -| 49 | No Tversky | — | — | **1.1888** | 0.0002 | -| 50 | Attn proj only | 128 | 1 | 1.1893 | 0.0000 | -| 51 | Attn proj only | 256 | 1 | 1.1894 | 0.0001 | -| 52 | Attn proj only | 32 | 1 | 1.1898 | 0.0001 | -| 53 | Attn + head | 128 | 1 | 1.1892 | — | -| 54 | Attn + head | 128 | 0 (local) | 1.1897 | +0.0006 | - -All variants within 0.001–0.002 bpb of baseline — pure noise. Confirmed by synthetic-data analysis that Tversky's asymmetric similarity only helps on tasks with directional feature relationships, which next-token prediction on web text is not. - ---- - -## Phase 3 — Architecture Exploration (Post-Optimizer-Fix) - -### Width vs Depth - -The central Phase 3 finding: wider models with fewer layers beat deeper models. - -#### 768d Scaling Curve - -| Run | Layers | Steps | step_avg | val_bpb | Artifact | -|-----|--------|-------|----------|---------|----------| -| 34 | 8 | 8110 | 74ms | 1.2894 | 12.94MB | -| 30 | 12 | 5640 | 106ms | 1.1893 | 17.50MB | -| 38 | 14 | 4900 | 122ms | 1.1870 | 19.79MB | -| 33/37 | 16 | 4320 | 139ms | 1.1825–37 | 22.08MB | -| 39 | 18 | 3870 | 155ms | 1.1801 | 24.39MB | -| 36 | 20 | 3510 | 171ms | 1.1854 | 26.67MB | - -Peak at 18L, then step penalty dominates. 8L collapses (U-Net encoder too shallow). Seed variance: Run 33 vs 37 = 0.0012 bpb. - -#### Cross-Architecture Comparison - -| Config | Layers | Dim | Steps | val_bpb | -|--------|--------|-----|-------|---------| -| F22 | 25 | 512 | 4720 | 1.2012 | -| Run 30 | 12 | 768 | 5640 | 1.1893 | -| Run 40 | 8 | 1024 | 5870 | 1.1858 | -| Run 41 | 10 | 896 | 5400 | 1.1862 | -| Run 35 | 20 | 640 | 4170 | 1.1927 | -| Run 42 | 6 | 896 | 8510 | 1.2157 | - -Width beats depth: 12L 768d (1.1893) beats 25L 512d (1.2012). Minimum viable depth: 768d ~10–12L, 896d ~10L, 1024d ~8L. - -### FP8 at 768d - -| Run | Layers | Storage | val_bpb | RT bpb | RT gap | -|-----|--------|---------|---------|--------|--------| -| 49 | 12 | fp16 | 1.1888 | 1.1886 | 0.0002 | -| 42 | 13 | fp8 | 1.1879 | 1.1900 | 0.0021 | - -FP8 RT gap acceptable at 768d. Enables extra layers within budget. - -### LM_HEAD_RANK Investigation (Post-Fix, 768d) - -| Run | Config | val_bpb | RT bpb | Total | Notes | -|-----|--------|---------|--------|-------|-------| -| Run 49 | baseline | 1.1888 | 1.1886 | 17.50MB | Reference | -| Run 43 | TIE=2, rank=256, fp8 | 1.2021 | 1.2028 | 20.41MB | Artifact bloated | -| Run 44 | TIE=0, rank=512, untie=0.0 | 1.3196 | 1.3195 | 16.92MB | Random head, no learning | -| Run 45 | TIE=2, rank=512, fp16 | 1.2312 | 1.2317 | 26.87MB | Catastrophic artifact blowup | - -Root cause: the SVD factors U and V require fp16/fp8 precision to maintain approximation quality. At any viable compression level, the two new matrices cost more storage than the original tied embedding saves. **Not viable.** - ---- - -## Phase 4 — Final Architecture Search - -### Activation Sweep (12L 768d 3×MLP, 600s) - -| Run | Activation | MLP | ms/step | Steps | val_bpb | Artifact | -|-----|-----------|-----|---------|-------|---------|----------| -| F55 | relu | 2× | 88.7 | 6760 | 1.2284 | 14.49MB | -| **F56** | **relu²** | **2×** | **89.5** | **6700** | **1.2042** | **14.48MB** | -| F60 | leaky relu | 3× | 102.6 | 5840 | 1.2094 | 17.50MB | -| **F57** | **relu²** | **3×** | **101.5** | **5910** | **1.1878** | **17.51MB** | -| F58 | swiglu | 3× | 127.4 | 4700 | 1.1786 | 22.05MB | -| **F59** | **swiglu** | **3×** | **127.3** | **4710** | **1.1771** | **21.96MB** | - -relu² beats relu by 0.024 bpb at no cost — strictly dominant. relu² locked for budget-constrained path. - -### MLP Width Sweep (600s) - -| Run | Activation | MLP | Layers | ms/step | Steps | val_bpb | Artifact | -|-----|-----------|-----|--------|---------|-------|---------|----------| -| F56 | relu² | 2× | 12 | 89.5 | 6700 | 1.2042 | 14.48MB | -| F64 | relu² | 3× | 12 | 99.4 | 6030 | 1.1873 | 17.50MB | -| F75 | relu² | 4× | 12 | 91.6 | 6550 | 1.1795 | 20.54MB | -| F82 | relu² | 4× | 10 | 91.6 | 6550 | 1.1861 | 16.04MB | - -4× MLP at 10L beats 3× at 12L within similar budget. - -### Layer Count vs MLP Width (fp8, 600s) - -| Run | Config | Layers | ms/step | Steps | val_bpb | RT bpb | Artifact | -|-----|--------|--------|---------|-------|---------|--------|----------| -| F78 | relu² 3× fp8 | 12 | 99.3 | 6040 | 1.1884 | 1.1898 | 15.80MB | -| F77 | relu² 3× fp8 | 13 | 106.6 | 5630 | 1.2065 | 1.2077 | 16.96MB | -| F80 | relu² 2× fp8 | 15 | 106.9 | 5610 | 1.2120 | 1.2136 | 15.45MB | -| F81 | relu² 2× fp8 | 16 | 113.9 | 5270 | 1.1996 | 1.2009 | 16.33MB | -| F79 | relu² 3× fp8 | 11 | 91.5 | 6560 | 1.1920 | 1.1933 | 14.66MB | -| **F82** | **relu² 4× fp8** | **10** | **91.6** | **6550** | **1.1861** | **1.1877** | **16.04MB** | -| F83 | swiglu 3× fp8 | 10 | 105.5 | 5690 | 1.1842 | 1.1853 | 17.29MB | - -### Weight Decay at 10L 4×MLP fp8 - -| Run | WD | val_bpb | RT bpb | Artifact | -|-----|-----|---------|--------|----------| -| F82 | 0.04 | 1.1861 | 1.1877 | 16.04MB | -| F84 | 0.08 | 1.1983 | 1.1998 | 16.04MB | -| **F85** | **0.00** | **1.1828** | **1.1844** | **16.02MB** | -| S87 | 0.00 | 1.1831 | 1.1843 | 16.01MB | -| **F88** | **0.00 (EMBED=254)** | **1.1820** | **1.1839** | **16.00MB — FITS** | - -WD=0 optimal at 10L 4× — opposite to 26L result. Wider MLP needs full weight freedom. - ---- - -## Binary Quantisation Track - -### Motivation - -Binary quantisation constrains weights to {-1, +1} with no zero state. At 1 bit/param vs ternary's 1.6 bits/param, binary packs approximately 60% more parameters per MB. The hypothesis was that additional depth could compensate for the loss of the zero state. - -Starting point: the ternary best config (10L, 768d, 8h, 4kv, 4× relu², FP8, 524k batch, 599s) scoring 1.1578 sliding bpb. - -### Binary Scaling Runs - -| Run | Layers | MLP | FP | Other | Steps | ms/step | Sliding bpb | Artifact | Fits | -|-----|--------|-----|-----|-------|-------|---------|-------------|----------|------| -| F17 | 17 | 4× | FP8 | — | 4010 | 149 | 1.2022 | 17.45MB | No | -| **F1** | **14** | **4×** | **FP8** | **—** | **4820** | **124** | **1.1824** | **14.74MB** | **Yes** | -| F2 | 14 | 4× | FP8 | EMA | 4800 | 125 | 1.2110 | 14.56MB | Yes | -| S3 | 15 | 4× | FP8 | — | 1000 | 133 | 1.3114 | 15.65MB | Yes | -| S4 | 20 | 3× | FP8 | — | 1000 | 160 | 1.3077 | 16.90MB | No | -| S5 | 21 | 3× | FP4 | — | 1000 | 167 | 1.3676 | 16.64MB | No | -| S6 | 19 | 3× | FP8 | — | 1000 | 152 | 1.3130 | 16.16MB | No | -| S7 | 15 | 4× | FP8 | refiner | 1000 | 135 | 1.3123 | 15.89MB | Yes | -| S8 | 15 | 4× | FP8 | smear | 1000 | 155 | 1.3043 | 15.67MB | Yes | -| S9 | 15 | 4× | FP8 | tversky_attn | 1000 | 179 | 1.4016 | 15.74MB | Yes | - -### Key Decisions from Binary Scaling - -**MLP width (4× vs 3×):** 4× won even when 3× received 4–5 extra layers. S3 (15L 4×) outperformed S6 (19L 3×) at matched steps. Width matters more than depth past a minimum viable layer count. - -**FP storage (FP8 vs FP4):** FP4 added a 0.06 bpb roundtrip penalty and was immediately ruled out. FP8 used for all non-binary tensors. - -**Layer count:** 17L was the theoretical maximum at 4× FP8 but landed 1.45MB over budget. 15L at 15.65MB was the maximum that fit. 14L left 1.26MB headroom. - -**EMA:** Mathematically sound for binary (no zero bucket means `mean(|Q|)=1.0` always, clean roundtrip). In practice, 0.03 bpb worse — the smoothed weights apparently hurt binary's learning dynamics despite the clean quantisation math. - -**Smear:** 0.007 bpb gain at 1000 steps but added 22ms/step overhead (133→155ms). Retained for the extended binary run to test whether the gain survives the step penalty at longer training. - -**Refiner (causal conv):** Neutral at 1000 steps, added 2ms/step. Not justified. - -**Tversky attention projection:** 0.09 bpb worse. Completely incompatible with binary weights. - -**Activation:** relu² inherited from ternary sweeps, not retested for binary. SwiGLU would cost ~4MB extra across 15 layers, eliminating the layer budget advantage. - -### Extended Binary Run (Unconstrained Compute) - -To measure the binary architecture's convergence ceiling without the 10-minute wallclock constraint, a single extended run was conducted at 50,000 steps (~2 hours on 8×H100). - -**Configuration:** 15L 768d, 4× relu², FP8, smear, 524k batch tokens, seed=42, MUON_WD=0.0 - -``` -step:50000/50000 val_loss:2.9692 val_bpb:1.1497 train_time:7763s -artifact:15.60MB binary:97320960(13685760B) fp:2542200(2585072B) code:70399 -budget:15670651/16000000 (15.67/16.00MB) FITS -final_binary_roundtrip val_loss:2.9743 val_bpb:1.1516 -temp_scaling optimal_T:0.90 -final_sliding val_loss:2.9027 val_bpb:1.1239 (stride=16, T=0.90) -``` - -| Metric | Value | -|--------|-------| -| val_bpb | 1.1497 | -| RT bpb | 1.1516 | -| Sliding bpb | **1.1239** | -| Artifact | 15.60MB (15.67MB total) | -| Params | 97,320,960 | -| Steps | 50,000 | -| ms/step | 155.3 | -| Training time | ~2.15 hours | - -The 1.1239 sliding bpb demonstrates that with sufficient compute the binary architecture reaches strong quality. This validates the compression approach — nearly 100M parameters in 15.67MB via 1-bit quantisation — though the 50k steps required far exceeds the competition's 10-minute budget. - -### Binary vs Ternary at Equal Architecture (Dev Scale) - -| Metric | Binary | Ternary | Delta | -|--------|--------|---------|-------| -| val_bpb | 1.8609 | 1.8113 | Ternary wins by 0.050 | -| Artifact | 9.14MB | 11.56MB | Binary saves 2.42MB | -| ms/step | 918 | 924 | Identical | -| RT gap | 0.000 | 0.000 | Both clean | - -Ternary is better at equal architecture. Binary's only advantage is fitting more layers in the same budget. - -### Binary Conclusion - -Binary lost the depth-for-sparsity trade. The 5 extra layers (15L binary vs 10L ternary) could not overcome ternary's representational advantage from the zero state. The 0.0016 bpb gap measured at 500 dev steps significantly understated the true difference at convergence. Ternary at 1.1578 sliding bpb (10-minute budget) outperforms binary's best fitting run (F1: 1.1824 at 14L without smear) by 0.025 bpb. Even the over-budget 17L binary run (1.2022) could not match ternary. - -The extended 50k-step binary run reaching 1.1239 sliding bpb shows that binary has a competitive convergence ceiling, but it requires approximately 8× more training steps to approach competitive quality — well beyond the competition constraints. - ---- - -## Grouped MLP Investigation - -Tested GroupedTernaryLinear: splits MLP into independent groups for parameter/speed savings. - -### Real Model Results (relu² 3×, 768d, 600s) - -| Run | Config | Layers | ms/step | Steps | val_bpb | Artifact | -|-----|--------|--------|---------|-------|---------|----------| -| F64 | standard | 12 | 99.4 | 6030 | 1.1873 | 17.50MB | -| F72 | g=2 | 12 | 87.4 | 6870 | 1.2180 | 12.97MB | -| F71 | g=4 | 12 | 83.5 | 7190 | 1.2429 | 10.74MB | -| F73 | g=2 | 16 | 114.2 | 5260 | 1.2037 | 16.04MB | -| F74 | swiglu g=2 | 12 | 113.3 | 5300 | 1.2084 | 15.24MB | - -Cross-group isolation costs 0.031–0.056 bpb. Even with 4 extra layers (F73), only recovers 0.014 of the deficit. **Not viable for language modelling.** - ---- - -## Differential Attention - -Microsoft (2024): computes two attention maps from split Q/K and takes their difference. - -| Run | Config | ms/step | Steps | val_bpb | -|-----|--------|---------|-------|---------| -| F64 | standard | 99.4 | 6030 | 1.1873 | -| F68 | diff_attn | 109.3 | 5480 | 1.2094 | - -Splits 96-dim heads into 48-dim sub-heads — insufficient dimensionality for meaningful attention patterns at this model scale. - ---- - -## Sequence Refiner (CausalConvRefiner) - -| Run | Config | ms/step | Steps | val_bpb | Artifact | -|-----|--------|---------|-------|---------|----------| -| F64 | none | 99.4 | 6030 | 1.1873 | 17.50MB | -| F69 | k=3 | 102.2 | 5860 | 1.1885 | 19.92MB | -| F70 | k=5 | 103.0 | 5820 | 1.2018 | 18.13MB | - -Noise-level quality improvement with storage bloat. 12 attention layers already saturate local pattern capture. - ---- - -## ByteCNN Vocabulary Generator - -Replaces `nn.Embedding(8192, 256)` with a CNN that generates the embedding matrix from byte spellings. - -``` -step:500 loss:9.0471 — step:2000 loss:9.0471 (flat, no learning) -``` - -All 8192 CNN-generated embeddings converge to near-identical vectors at initialisation. The CNN's inductive bias (byte-similar tokens → similar embeddings) destroys the initial diversity needed for gradient signal. - ---- - -## Asymmetric Tokenizer Investigation - -8k BPE input with 256-byte output to eliminate large output projection. - -| Model | BPB | Notes | -|-------|-----|-------| -| Standard (tied, emb=256) | 3.10 | reference | -| Asymmetric parallel (emb=256) | 8.65 | byte independence assumption fails | -| Asymmetric autoregressive (emb=256) | 8.17 | tiny GRU insufficient capacity | - -Multi-byte parallel heads assume conditional independence between bytes within a token — mathematically incorrect. Sequence-length mismatch (7 BPE tokens → 70 bytes) also incompatible with the evaluation framework. - ---- - -## Linear Alternative Exploration - -Systematic notebook testing of linear layer alternatives at real model dimensions (768d). - -### Projection Benchmark (DIM → DIM, H100) - -| Model | Params | ms | vs Linear | -|-------|--------|-----|-----------| -| Linear | 589,824 | 0.07ms | 1.00× | -| LowRank r=64 | 98,304 | 0.03ms | 0.44× | -| BlockDiag b=4 | 147,456 | 0.03ms | 0.40× | -| Grouped g=4 | 147,456 | 0.03ms | 0.40× | -| BD4 + mix32 | 196,608 | 0.07ms | 0.97× | -| Hash 65536 | 65,536 | 0.08ms | 1.13× | - -BlockDiag/Grouped offer speed advantages but cross-group isolation degrades LM quality in practice. - ---- - -## H100 Microbenchmark Results - -Standalone kernel timing vs torch.compile behaviour (critical lesson: standalone microbenchmarks can mislead when torch.compile fuses operations). - -### STE Speed - -| Variant | ms/call | -|---------|---------| -| Current | 0.041 | -| Reciprocal | 0.043 | - -No gain — 48 STE calls/step = ~2ms overhead (unavoidable). - -### Contiguous Checks - -Q and K are contiguous after RoPE. V is non-contiguous (view into fused QKV). V's `.contiguous()` costs 0.065ms/call = 0.78ms/step (necessary for flash_attn). - -### RoPE Variants - -Current (half-split + cat) is fastest at 0.52ms/call. - -### Softcap: Poly5 vs Tanh - -| Variant | ms/call | -|---------|---------| -| Poly5 (current) | 8.43 | -| Poly3 | 5.98 | -| Tanh | 2.12 | -| Hardtanh | 0.71 | - -**Critical finding:** Tanh is 4× faster standalone due to H100 hardware transcendental units. However in the real training loop, torch.compile fuses poly5 with surrounding ops into a single kernel. **Switching to tanh broke fusion — F63 was 16ms/step slower.** Poly5 retained. - -### CE + Z-Loss Fusion - -| Variant | ms/call (fwd+bwd) | -|---------|-------------------| -| Separate (current) | 16.56 | -| Fused (shared LSE) | 12.33 | - -**Same lesson:** 4.2ms saving standalone, but torch.compile already optimises `F.cross_entropy`. Manual gather+logsumexp prevents optimisation. Current approach retained. - ---- - -## Efficiency Analysis - -### BPB Gained Per Component - -| Component | BPB gain | Source | -|-----------|----------|--------| -| relu → relu² | −0.024 | F55 vs F56 | -| MLP 2× → 3× (relu²) | −0.017 | F56 vs F64 | -| MLP 3× → 4× (relu²) | −0.008 | F64 vs F75 | -| relu² → swiglu (at 3×) | −0.010 | F64 vs F59 | -| +1 layer (average) | −0.0012 | scaling data | -| fp16 → fp8 (RT penalty) | +0.002 | run 42 vs 49 | -| Sliding eval stride=16 | −0.025 | F22 data | -| WD=0.04 vs WD=0 (at 26L) | −0.001 | F7 vs F6 | - -### MB Cost Per Component - -| Component | MB/layer | -|-----------|----------| -| relu² 2× layer | 0.767 | -| relu² 3× layer | 1.003 | -| relu² 4× layer | 1.220 | -| swiglu 3× layer | 1.357 | -| fp16 → fp8 (fixed saving) | −2.51 | - -### Efficiency Ratio (BPB Gained Per MB Spent) - -| Change | BPB gain | MB cost | BPB/MB | -|--------|----------|---------|--------| -| relu → relu² | −0.024 | 0.00 | infinite (free) | -| Sliding eval | −0.025 | 0.00 | infinite (free) | -| MLP 2× → 3× | −0.017 | +2.83 (12L) | −0.0060/MB | -| MLP 3× → 4× | −0.008 | +2.83 (12L) | −0.0028/MB | -| relu² → swiglu | −0.010 | +4.25 (12L) | −0.0024/MB | -| +1 layer (relu² 2×) | −0.0012 | +0.767 | −0.0016/MB | -| +1 layer (relu² 3×) | −0.0012 | +1.003 | −0.0012/MB | - -MLP 2×→3× is the most efficient paid upgrade. relu² and sliding eval are free wins. - -### Layer Budget at 768d - -| Config | Max Layers | Est ms/step | -|--------|-----------|-------------| -| relu² 2× fp16 | 14L | ~95ms | -| relu² 2× fp8 | 17L | ~97ms | -| relu² 3× fp16 | 10L | ~99ms | -| relu² 3× fp8 | 13L | ~106ms | -| relu² 4× fp8 | 10L | ~92ms | -| swiglu 3× fp8 | 9L | ~105ms | - ---- - -## Ternary-Incompatible Techniques - -These are not merely unhelpful but structurally incompatible with 1.58-bit quantisation: - -| Technique | Mechanism of failure | -|-----------|---------------------| -| **EMA** | Weight averaging → values cluster near zero → ternary rounds most to 0 → 0.12 bpb RT gap | -| **TTT-LoRA** | LoRA delta computed outside RMSNorm space that TernaryLinear normalises into. Corrupts calibrated representations at convergence | -| **Ternary prototypes + sigmoid** | Sigmoid membership needs continuous values. Ternary {-1,0,+1} collapses membership patterns → 0.077 RT gap | -| **LM head rank factorisation** | SVD factors U,V need fp16 precision. Storage exceeds original tied embedding | - ---- - -## Software Optimisations - -| Optimisation | Saving | Notes | -|---|---|---| -| Fused QKV (c_q+c_k+c_v → single matmul) | ~2ms/step | Safe: in_features divisible by all group sizes | -| Fused SwiGLU/relu² (gate+up → single wide matmul) | ~2-4ms/step | Same params, fewer kernel launches | -| Z-loss regularisation (1e-4 x logsumexp²) | quality | Anchors logits, keeps STE gradients sharp | -| DataLoader int16 transfer (pin then cast on GPU) | ~1ms/step | 4× less PCIe bandwidth | -| FlashAttention-3 | ~13ms/step | ~9% speedup, ~380 free training steps | -| TernaryLinear bf16 weights, cleaner STE | ~1ms/step | Eliminates fp32 roundtrip | -| DDP static_graph + gradient_as_bucket_view | ~1ms/step | Free when find_unused=False | -| Fused optimizer loop (LR set + step in one pass) | ~0.5ms/step | Fewer Python-level iterations | -| Removed CUBLAS determinism tax | ~1ms/step | Not required for competition | -| Temperature grid: 5 points instead of 21 | ~1s total | T=0.90 consistently with relu² | -| Temp scaling moved to eval phase | ~3 steps gained | No longer steals training time | -| `_e()` helper for Hyperparameters | -1.8KB code | Eliminates env var boilerplate | -| 3D tensor ternary quantisation | storage fix | Conv1d weights reshaped to 2D for ternary | - ---- - -## Rejected Techniques (Summary) - -| Technique | Reason | -|-----------|--------| -| Tversky (all variants) | Quality-neutral on FineWeb LM — confirmed via synthetic data analysis; speed penalty with relu² | -| Differential attention | Halved head_dim (96→48) degrades quality at this model scale | -| Grouped MLP (g=2, g=4) | Cross-group isolation costs 0.031–0.056 bpb; not recoverable with extra layers | -| CausalConvRefiner | Noise-level quality; storage bloat from Conv1d weights | -| ByteCNN vocabulary generator | Embedding collapse — CNN inductive bias destroys initial diversity | -| Asymmetric tokenizer | Byte independence assumption incorrect; sequence mismatch with eval framework | -| EMA | Incompatible with ternary — weight averaging causes 0.12 bpb RT gap | -| TTT-LoRA | Architectural incompatibility with RMSNorm space in TernaryLinear | -| LM head factorisation | SVD factors bloat artifact beyond budget; unrecoverable quality loss | -| MTP | 0.006 bpb worse — model capacity too limited for auxiliary objectives | -| BigramHash | 0.020 bpb worse at convergence; fp16 table displaces ternary layers | -| Seq/batch schedule | Recompile and step penalties dominate at 600s wallclock | -| SmearModule | +22% step cost for −0.001 gain within ternary 10-minute budget | -| Depth recurrence | Halves effective steps; OOM at DR=3 | -| AdamW for matrix params | Clearly inferior to Muon for ternary weights | -| FP4 storage | 0.026–0.029 RT gap even with QAT — unrecoverable | -| Tanh softcap | Faster standalone but breaks torch.compile kernel fusion | -| Fused CE+Z-loss | Same — breaks compile optimisation | -| 16 heads at 768d | 48-dim head_dim insufficient for meaningful attention | -| relu (plain) | Strictly dominated by relu² | -| leaky relu | Strictly dominated by relu² | -| Distillation (in-run) | Train-from-scratch teacher always worse than supervised | -| reduce-overhead compile | Rotary + embed_proj_rev incompatible with CUDA graphs | -| max-autotune compile | 30+ minute kernel search prohibitive for 600s runs | -| Skip weights zero-init | 0.010 bpb worse — decoder needs skip signal from step 0 | -| EMBED_DIM=0 (full 512) | 19.78MB artifact — 3.78MB over budget | -| Untie lm_head full-rank | 7.3MB budget overrun not justified by 0.005 bpb gain | - ---- - -## Decision Log - -| Decision | Rationale | -|----------|-----------| -| 8k vocabulary | −0.42 bpb, largest single win | -| relu² activation | −0.024 bpb vs relu, free (no cost) | -| 4×MLP width | Best BPB within budget at 10L; 0.008 better than 3× | -| 10L 768d | Minimum viable depth at 768d with maximum MLP width | -| WD=0.0 at 10L 4× | Opposite to deep models — wider MLP needs full weight freedom | -| fp8 storage | Halves fp_params (5MB→2.5MB), enables wider MLP within budget | -| EMBED_DIM=254 | 256-2 dims to fit artifact+code under 16,000,000 byte budget; ~0.0004 bpb cost | -| BITNET_GROUP_SIZE=128 | Same quality as 64; saves 0.69MB | -| 8 heads, 4 KV, 96-dim head_dim | 16h at 48-dim insufficient; MHA only +0.0012 at +1.5MB | -| Poly softcap | Fuses with torch.compile; tanh breaks fusion | -| ROPE_BASE=5000 + YaRN 2048 | Best frequency calibration | -| Muon optimizer | Newton-Schulz normalisation compensates for ternary STE gradient attenuation | -| MUON_BACKEND_STEPS=3 | Equivalent to 5 at convergence; +190 extra steps | -| MUON_MOMENTUM=0.95 | Both directions degrade; affects artifact via zero_frac | -| WARMDOWN=20% | Asymmetric — too little hurts more than too much | -| MATRIX_LR=0.04 | Higher LR compensates for ternary STE gradient attenuation | -| SCALAR_LR=0.02 | Optimal — scalars do not pass through STE | -| TIED_EMBED_LR=0.02 | Optimal | -| TRAIN_BATCH_TOKENS=524k | Optimal tradeoff between gradient quality and step count | -| Base-3 + LZMA | 39% reduction over int8+zlib | -| Shrinkage fix | Eliminates all RT gaps universally | -| Skip weights ones-init | Decoder needs skip signal from step 0; zeros costs 0.010 bpb | -| Tied embeddings | Untie costs 7.3MB; not justified | -| Standard attn projection | Tversky quality-neutral; grouped destroys quality | -| No EMA | Fundamentally incompatible with ternary | -| No TTT | RMSNorm space incompatibility confirmed across 6 runs | -| No MTP | Confirmed post-fix: 0.006 bpb worse | -| Temperature scaling T=0.90 | relu² logits slightly underconfident; auto-calibrated | -| Fused QKV + relu² | ~130-180 free training steps per run | -| Z-loss regularisation | Anchors logits; keeps STE gradients sharp | -| FlashAttention-3 | Free ~380 extra training steps per 600s run | -| Sliding eval stride=16 | Best quality when eval budget unconstrained | -| Optimizer coverage fix | embed_proj/embed_proj_rev now train; +0.055 bpb improvement | -| MAX_WALLCLOCK_SECONDS=599 | 1s leeway for safety margin | -| Binary 15L 768d 4× fp8 | 97M params in 15.67MB — maximum parameter density; convergence ceiling validated at 50k steps | diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/fineweb_8192_bpe.model b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/fineweb_8192_bpe.model deleted file mode 100644 index 6574784f5f..0000000000 Binary files a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/fineweb_8192_bpe.model and /dev/null differ diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/fineweb_8192_bpe.vocab b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/fineweb_8192_bpe.vocab deleted file mode 100644 index 6e194bf03c..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/fineweb_8192_bpe.vocab +++ /dev/null @@ -1,8192 +0,0 @@ - 0 - 0 - 0 - 0 -<0x00> 0 -<0x01> 0 -<0x02> 0 -<0x03> 0 -<0x04> 0 -<0x05> 0 -<0x06> 0 -<0x07> 0 -<0x08> 0 -<0x09> 0 -<0x0A> 0 -<0x0B> 0 -<0x0C> 0 -<0x0D> 0 -<0x0E> 0 -<0x0F> 0 -<0x10> 0 -<0x11> 0 -<0x12> 0 -<0x13> 0 -<0x14> 0 -<0x15> 0 -<0x16> 0 -<0x17> 0 -<0x18> 0 -<0x19> 0 -<0x1A> 0 -<0x1B> 0 -<0x1C> 0 -<0x1D> 0 -<0x1E> 0 -<0x1F> 0 -<0x20> 0 -<0x21> 0 -<0x22> 0 -<0x23> 0 -<0x24> 0 -<0x25> 0 -<0x26> 0 -<0x27> 0 -<0x28> 0 -<0x29> 0 -<0x2A> 0 -<0x2B> 0 -<0x2C> 0 -<0x2D> 0 -<0x2E> 0 -<0x2F> 0 -<0x30> 0 -<0x31> 0 -<0x32> 0 -<0x33> 0 -<0x34> 0 -<0x35> 0 -<0x36> 0 -<0x37> 0 -<0x38> 0 -<0x39> 0 -<0x3A> 0 -<0x3B> 0 -<0x3C> 0 -<0x3D> 0 -<0x3E> 0 -<0x3F> 0 -<0x40> 0 -<0x41> 0 -<0x42> 0 -<0x43> 0 -<0x44> 0 -<0x45> 0 -<0x46> 0 -<0x47> 0 -<0x48> 0 -<0x49> 0 -<0x4A> 0 -<0x4B> 0 -<0x4C> 0 -<0x4D> 0 -<0x4E> 0 -<0x4F> 0 -<0x50> 0 -<0x51> 0 -<0x52> 0 -<0x53> 0 -<0x54> 0 -<0x55> 0 -<0x56> 0 -<0x57> 0 -<0x58> 0 -<0x59> 0 -<0x5A> 0 -<0x5B> 0 -<0x5C> 0 -<0x5D> 0 -<0x5E> 0 -<0x5F> 0 -<0x60> 0 -<0x61> 0 -<0x62> 0 -<0x63> 0 -<0x64> 0 -<0x65> 0 -<0x66> 0 -<0x67> 0 -<0x68> 0 -<0x69> 0 -<0x6A> 0 -<0x6B> 0 -<0x6C> 0 -<0x6D> 0 -<0x6E> 0 -<0x6F> 0 -<0x70> 0 -<0x71> 0 -<0x72> 0 -<0x73> 0 -<0x74> 0 -<0x75> 0 -<0x76> 0 -<0x77> 0 -<0x78> 0 -<0x79> 0 -<0x7A> 0 -<0x7B> 0 -<0x7C> 0 -<0x7D> 0 -<0x7E> 0 -<0x7F> 0 -<0x80> 0 -<0x81> 0 -<0x82> 0 -<0x83> 0 -<0x84> 0 -<0x85> 0 -<0x86> 0 -<0x87> 0 -<0x88> 0 -<0x89> 0 -<0x8A> 0 -<0x8B> 0 -<0x8C> 0 -<0x8D> 0 -<0x8E> 0 -<0x8F> 0 -<0x90> 0 -<0x91> 0 -<0x92> 0 -<0x93> 0 -<0x94> 0 -<0x95> 0 -<0x96> 0 -<0x97> 0 -<0x98> 0 -<0x99> 0 -<0x9A> 0 -<0x9B> 0 -<0x9C> 0 -<0x9D> 0 -<0x9E> 0 -<0x9F> 0 -<0xA0> 0 -<0xA1> 0 -<0xA2> 0 -<0xA3> 0 -<0xA4> 0 -<0xA5> 0 -<0xA6> 0 -<0xA7> 0 -<0xA8> 0 -<0xA9> 0 -<0xAA> 0 -<0xAB> 0 -<0xAC> 0 -<0xAD> 0 -<0xAE> 0 -<0xAF> 0 -<0xB0> 0 -<0xB1> 0 -<0xB2> 0 -<0xB3> 0 -<0xB4> 0 -<0xB5> 0 -<0xB6> 0 -<0xB7> 0 -<0xB8> 0 -<0xB9> 0 -<0xBA> 0 -<0xBB> 0 -<0xBC> 0 -<0xBD> 0 -<0xBE> 0 -<0xBF> 0 -<0xC0> 0 -<0xC1> 0 -<0xC2> 0 -<0xC3> 0 -<0xC4> 0 -<0xC5> 0 -<0xC6> 0 -<0xC7> 0 -<0xC8> 0 -<0xC9> 0 -<0xCA> 0 -<0xCB> 0 -<0xCC> 0 -<0xCD> 0 -<0xCE> 0 -<0xCF> 0 -<0xD0> 0 -<0xD1> 0 -<0xD2> 0 -<0xD3> 0 -<0xD4> 0 -<0xD5> 0 -<0xD6> 0 -<0xD7> 0 -<0xD8> 0 -<0xD9> 0 -<0xDA> 0 -<0xDB> 0 -<0xDC> 0 -<0xDD> 0 -<0xDE> 0 -<0xDF> 0 -<0xE0> 0 -<0xE1> 0 -<0xE2> 0 -<0xE3> 0 -<0xE4> 0 -<0xE5> 0 -<0xE6> 0 -<0xE7> 0 -<0xE8> 0 -<0xE9> 0 -<0xEA> 0 -<0xEB> 0 -<0xEC> 0 -<0xED> 0 -<0xEE> 0 -<0xEF> 0 -<0xF0> 0 -<0xF1> 0 -<0xF2> 0 -<0xF3> 0 -<0xF4> 0 -<0xF5> 0 -<0xF6> 0 -<0xF7> 0 -<0xF8> 0 -<0xF9> 0 -<0xFA> 0 -<0xFB> 0 -<0xFC> 0 -<0xFD> 0 -<0xFE> 0 -<0xFF> 0 -▁t -0 -▁a -1 -in -2 -he -3 -re -4 -on -5 -er -6 -▁the -7 -▁s -8 -▁w -9 -or -10 -at -11 -nd -12 -ou -13 -▁c -14 -it -15 -es -16 -▁f -17 -is -18 -ing -19 -en -20 -▁b -21 -▁p -22 -▁o -23 -an -24 -ed -25 -▁to -26 -al -27 -▁m -28 -ar -29 -▁and -30 -▁in -31 -▁of -32 -▁d -33 -le -34 -ic -35 -as -36 -▁h -37 -om -38 -ion -39 -▁th -40 -il -41 -▁T -42 -▁l -43 -ent -44 -ve -45 -▁I -46 -ro -47 -st -48 -▁y -49 -▁e -50 -▁re -51 -▁n -52 -▁S -53 -▁g -54 -et -55 -ct -56 -▁A -57 -▁C -58 -▁you -59 -ly -60 -ay -61 -id -62 -▁for -63 -▁on -64 -▁is -65 -ot -66 -▁be -67 -ow -68 -ol -69 -am -70 -ac -71 -ig -72 -us -73 -ad -74 -el -75 -▁M -76 -im -77 -ver -78 -ith -79 -ut -80 -▁st -81 -▁P -82 -ation -83 -▁with -84 -ur -85 -▁B -86 -▁that -87 -ir -88 -▁W -89 -ch -90 -▁he -91 -▁it -92 -▁The -93 -ce -94 -ill -95 -ers -96 -un -97 -▁al -98 -▁D -99 -ul -100 -▁an -101 -▁H -102 -▁F -103 -out -104 -ra -105 -ke -106 -▁pro -107 -▁wh -108 -▁as -109 -▁are -110 -se -111 -ter -112 -▁we -113 -▁ha -114 -▁R -115 -oo -116 -if -117 -ge -118 -our -119 -pp -120 -▁at -121 -ate -122 -ess -123 -▁com -124 -▁or -125 -▁con -126 -▁L -127 -her -128 -ore -129 -est -130 -▁fr -131 -ment -132 -igh -133 -▁- -134 -ab -135 -▁N -136 -▁se -137 -▁ne -138 -ld -139 -ort -140 -▁G -141 -▁E -142 -ri -143 -ist -144 -▁( -145 -▁your -146 -op -147 -▁O -148 -▁ex -149 -em -150 -ure -151 -ity -152 -▁r -153 -ant -154 -qu -155 -▁v -156 -▁was -157 -art -158 -ust -159 -▁have -160 -ive -161 -um -162 -▁this -163 -▁from -164 -pe -165 -▁de -166 -oc -167 -▁sh -168 -th -169 -ain -170 -up -171 -ies -172 -▁will -173 -▁by -174 -ight -175 -▁ch -176 -and -177 -os -178 -▁can -179 -ie -180 -nt -181 -all -182 -▁us -183 -ome -184 -▁not -185 -ard -186 -ud -187 -▁le -188 -res -189 -▁J -190 -ast -191 -.. -192 -ost -193 -▁pl -194 -ear -195 -▁ab -196 -ack -197 -▁su -198 -iv -199 -▁wor -200 -gh -201 -▁all -202 -rou -203 -ide -204 -ould -205 -▁j -206 -ell -207 -ial -208 -te -209 -ak -210 -ine -211 -od -212 -ag -213 -are -214 -▁has -215 -ice -216 -▁U -217 -▁Th -218 -▁do -219 -age -220 -▁k -221 -ook -222 -fe -223 -▁ad -224 -▁me -225 -ip -226 -▁In -227 -▁comp -228 -▁but -229 -▁up -230 -▁out -231 -ake -232 -per -233 -red -234 -▁whe -235 -ions -236 -ally -237 -pt -238 -ry -239 -og -240 -one -241 -▁more -242 -ail -243 -able -244 -ind -245 -▁my -246 -ite -247 -▁our -248 -ther -249 -▁en -250 -▁“ -251 -very -252 -▁Y -253 -▁sa -254 -▁so -255 -ich -256 -ime -257 -cc -258 -▁cl -259 -ong -260 -▁their -261 -▁K -262 -ated -263 -ood -264 -ame -265 -orm -266 -▁St -267 -▁they -268 -▁one -269 -▁te -270 -ber -271 -ace -272 -ike -273 -iz -274 -▁about -275 -so -276 -ous -277 -du -278 -ick -279 -ase -280 -ans -281 -▁" -282 -▁V -283 -pl -284 -▁cont -285 -act -286 -ia -287 -▁im -288 -▁work -289 -▁un -290 -▁who -291 -ree -292 -cl -293 -ire -294 -▁fe -295 -ign -296 -▁off -297 -▁his -298 -▁man -299 -ue -300 -ff -301 -ance -302 -▁go -303 -ll -304 -ach -305 -▁year -306 -▁new -307 -▁tr -308 -ays -309 -ne -310 -reat -311 -▁It -312 -ction -313 -ub -314 -ib -315 -ult -316 -▁app -317 -erv -318 -und -319 -▁We -320 -ap -321 -▁Ch -322 -ass -323 -▁qu -324 -ep -325 -▁res -326 -ary -327 -ark -328 -▁sp -329 -▁per -330 -ations -331 -ile -332 -ove -333 -form -334 -▁int -335 -▁get -336 -▁also -337 -▁time -338 -▁which -339 -ount -340 -ven -341 -▁like -342 -own -343 -▁other -344 -ents -345 -▁some -346 -ond -347 -ord -348 -▁any -349 -ings -350 -vel -351 -av -352 -▁been -353 -ical -354 -▁over -355 -▁part -356 -ress -357 -▁This -358 -▁dis -359 -ks -360 -▁He -361 -ors -362 -ence -363 -▁said -364 -▁sc -365 -▁rec -366 -▁ar -367 -ition -368 -▁them -369 -▁ag -370 -▁when -371 -▁pe -372 -ild -373 -port -374 -▁her -375 -ound -376 -ough -377 -▁kn -378 -ose -379 -ob -380 -irst -381 -low -382 -▁just -383 -mer -384 -int -385 -▁ro -386 -ov -387 -ck -388 -ish -389 -▁what -390 -oy -391 -▁pr -392 -ru -393 -▁spe -394 -▁pre -395 -▁there -396 -ens -397 -wn -398 -▁acc -399 -day -400 -▁if -401 -ren -402 -▁than -403 -▁would -404 -▁need -405 -▁Re -406 -▁had -407 -vers -408 -▁its -409 -▁were -410 -ink -411 -fter -412 -ning -413 -▁am -414 -ater -415 -... -416 -▁des -417 -old -418 -itt -419 -clud -420 -ade -421 -rough -422 -▁tw -423 -▁into -424 -lp -425 -ory -426 -use -427 -ople -428 -ool -429 -ang -430 -▁first -431 -▁how -432 -▁bec -433 -▁help -434 -lic -435 -hed -436 -ons -437 -▁add -438 -anc -439 -ft -440 -▁make -441 -amp -442 -gr -443 -▁bl -444 -▁look -445 -▁– -446 -▁Wh -447 -▁prov -448 -▁col -449 -▁includ -450 -▁people -451 -▁comm -452 -▁produ -453 -▁You -454 -▁Ne -455 -ual -456 -▁know -457 -ful -458 -▁she -459 -ian -460 -ments -461 -ates -462 -iew -463 -round -464 -▁em -465 -▁every -466 -▁back -467 -▁only -468 -▁serv -469 -tern -470 -les -471 -ious -472 -▁no -473 -▁may -474 -rent -475 -▁through -476 -▁bu -477 -ict -478 -▁most -479 -cts -480 -ating -481 -▁see -482 -▁want -483 -▁two -484 -▁ph -485 -com -486 -pport -487 -▁As -488 -xt -489 -we -490 -ities -491 -ices -492 -iss -493 -▁use -494 -▁well -495 -ont -496 -▁bet -497 -▁after -498 -▁If -499 -ise -500 -hing -501 -▁ind -502 -ause -503 -▁play -504 -▁Se -505 -ph -506 -▁und -507 -je -508 -▁& -509 -▁co -510 -ife -511 -▁| -512 -ock -513 -ily -514 -▁stud -515 -lect -516 -row -517 -▁act -518 -ting -519 -iness -520 -▁fl -521 -hen -522 -▁years -523 -▁Com -524 -▁Un -525 -urn -526 -ts -527 -▁$ -528 -enc -529 -aw -530 -▁these -531 -▁tra -532 -▁An -533 -fore -534 -▁cons -535 -▁under -536 -als -537 -cial -538 -ange -539 -▁exper -540 -bs -541 -aking -542 -▁ke -543 -oth -544 -▁now -545 -ures -546 -ational -547 -▁very -548 -▁Pro -549 -▁wee -550 -▁bus -551 -▁good -552 -▁gu -553 -ased -554 -vent -555 -▁And -556 -formation -557 -▁many -558 -▁sm -559 -get -560 -▁way -561 -any -562 -▁reg -563 -erson -564 -oint -565 -ific -566 -ward -567 -▁De -568 -ert -569 -ility -570 -▁start -571 -▁fin -572 -▁dif -573 -▁could -574 -rit -575 -lease -576 -▁great -577 -▁imp -578 -ork -579 -uch -580 -▁day -581 -fect -582 -▁rem -583 -▁Sh -584 -yst -585 -▁rel -586 -ience -587 -ible -588 -▁even -589 -▁For -590 -uring -591 -ty -592 -▁show -593 -▁high -594 -oss -595 -ics -596 -▁sec -597 -ull -598 -▁own -599 -nds -600 -velop -601 -▁inv -602 -▁where -603 -▁here -604 -▁don -605 -▁inc -606 -▁down -607 -). -608 -▁ent -609 -ident -610 -hes -611 -olog -612 -cess -613 -▁loc -614 -arch -615 -▁right -616 -ble -617 -▁then -618 -chool -619 -▁home -620 -▁should -621 -▁Al -622 -▁New -623 -elf -624 -alth -625 -The -626 -▁ass -627 -ied -628 -▁br -629 -its -630 -ited -631 -▁find -632 -ath -633 -air -634 -ular -635 -▁read -636 -▁too -637 -▁ac -638 -hip -639 -▁av -640 -▁set -641 -ix -642 -▁car -643 -▁fam -644 -ner -645 -▁information -646 -▁mon -647 -gan -648 -line -649 -▁best -650 -▁last -651 -ys -652 -▁min -653 -gram -654 -▁take -655 -io -656 -▁design -657 -▁Cl -658 -pect -659 -ract -660 -▁long -661 -ason -662 -▁did -663 -▁inst -664 -▁much -665 -omet -666 -▁che -667 -|| -668 -erm -669 -▁Be -670 -▁business -671 -ystem -672 -▁because -673 -▁before -674 -other -675 -ank -676 -▁dec -677 -ues -678 -▁But -679 -▁att -680 -▁ins -681 -▁Fr -682 -.” -683 -▁made -684 -▁team -685 -ative -686 -▁call -687 -▁Le -688 -▁him -689 -pr -690 -▁sur -691 -pen -692 -atch -693 -▁cre -694 -rib -695 -me -696 -▁think -697 -ject -698 -ollow -699 -az -700 -▁again -701 -▁world -702 -way -703 -ax -704 -ale -705 -ug -706 -▁Ad -707 -▁art -708 -▁mem -709 -▁does -710 -alk -711 -), -712 -▁vis -713 -arket -714 -▁being -715 -▁pres -716 -ave -717 -▁develop -718 -▁person -719 -oun -720 -▁requ -721 -arn -722 -ustom -723 -ower -724 -chn -725 -rest -726 -▁inte -727 -arm -728 -ient -729 -▁life -730 -▁those -731 -ener -732 -▁diffe -733 -▁such -734 -ins -735 -▁med -736 -ng -737 -ivers -738 -ince -739 -ouse -740 -▁support -741 -ving -742 -▁while -743 -ash -744 -irect -745 -▁Ar -746 -▁pol -747 -view -748 -land -749 -▁sk -750 -▁provid -751 -ss -752 -unity -753 -ier -754 -▁lead -755 -▁ra -756 -▁Te -757 -▁each -758 -▁around -759 -▁book -760 -der -761 -▁love -762 -▁free -763 -▁used -764 -ced -765 -akes -766 -▁care -767 -▁end -768 -read -769 -▁mod -770 -ailable -771 -▁ser -772 -▁comple -773 -▁post -774 -▁run -775 -▁gr -776 -ather -777 -▁disc -778 -▁sim -779 -ric -780 -▁program -781 -ality -782 -▁ret -783 -▁pub -784 -ces -785 -ional -786 -ages -787 -ually -788 -▁bo -789 -▁cur -790 -▁ed -791 -ines -792 -imes -793 -ton -794 -ives -795 -▁All -796 -▁det -797 -▁really -798 -roup -799 -ple -800 -oad -801 -ars -802 -▁eas -803 -ets -804 -▁On -805 -▁child -806 -▁system -807 -▁There -808 -▁So -809 -▁num -810 -iel -811 -au -812 -ize -813 -▁follow -814 -▁trans -815 -." -816 -led -817 -ene -818 -▁count -819 -▁going -820 -▁found -821 -,” -822 -▁top -823 -ah -824 -▁form -825 -▁char -826 -▁somet -827 -iet -828 -▁three -829 -ittle -830 -▁inter -831 -▁list -832 -▁cour -833 -ames -834 -man -835 -▁still -836 -▁Bl -837 -▁fun -838 -▁How -839 -▁month -840 -▁available -841 -▁place -842 -▁del -843 -ature -844 -▁Pl -845 -▁custom -846 -ute -847 -ness -848 -▁though -849 -▁They -850 -▁feel -851 -ways -852 -▁prof -853 -▁cle -854 -▁both -855 -▁To -856 -▁few -857 -▁sub -858 -cept -859 -▁aut -860 -orn -861 -meric -862 -▁str -863 -▁happ -864 -▁week -865 -▁sign -866 -▁open -867 -▁hand -868 -ved -869 -▁gl -870 -▁pur -871 -▁say -872 -uc -873 -▁report -874 -▁health -875 -▁game -876 -▁adv -877 -att -878 -▁rep -879 -▁market -880 -ital -881 -▁different -882 -oot -883 -ired -884 -orth -885 -▁frie -886 -bers -887 -▁keep -888 -▁same -889 -ering -890 -tt -891 -▁lot -892 -▁Ex -893 -▁She -894 -▁point -895 -▁Col -896 -ween -897 -▁techn -898 -▁family -899 -▁ev -900 -▁i -901 -ology -902 -▁exp -903 -iqu -904 -▁ext -905 -▁school -906 -ining -907 -▁little -908 -▁using -909 -," -910 -▁process -911 -ished -912 -atur -913 -▁company -914 -▁lar -915 -ata -916 -▁including -917 -▁Sc -918 -ross -919 -iving -920 -oh -921 -ants -922 -▁next -923 -▁plan -924 -▁win -925 -▁Americ -926 -ott -927 -▁fil -928 -▁real -929 -▁during -930 -▁Tr -931 -▁between -932 -thing -933 -ized -934 -▁water -935 -ger -936 -▁sol -937 -▁Ph -938 -▁import -939 -▁Q -940 -ody -941 -cent -942 -▁state -943 -▁What -944 -gg -945 -ield -946 -▁things -947 -ik -948 -ves -949 -▁met -950 -arly -951 -els -952 -▁come -953 -aut -954 -ists -955 -be -956 -▁allow -957 -▁big -958 -less -959 -aint -960 -reen -961 -▁mus -962 -▁put -963 -▁contin -964 -uss -965 -▁Or -966 -▁rece -967 -▁experience -968 -ware -969 -▁service -970 -▁opt -971 -▁build -972 -cer -973 -self -974 -▁small -975 -▁dri -976 -▁days -977 -▁appro -978 -ined -979 -iversity -980 -ex -981 -▁organ -982 -▁full -983 -ling -984 -▁since -985 -▁cent -986 -▁always -987 -▁rest -988 -▁try -989 -▁phot -990 -▁better -991 -▁cr -992 -▁sure -993 -▁When -994 -ution -995 -▁pat -996 -▁online -997 -▁pri -998 -▁quest -999 -▁ref -1000 -▁Ind -1001 -▁second -1002 -▁pass -1003 -▁something -1004 -▁var -1005 -illion -1006 -▁bel -1007 -▁interest -1008 -rand -1009 -ever -1010 -over -1011 -▁iss -1012 -▁partic -1013 -▁class -1014 -▁poss -1015 -▁gener -1016 -▁def -1017 -▁group -1018 -▁tri -1019 -▁mov -1020 -ffect -1021 -▁perform -1022 -▁hard -1023 -▁direct -1024 -▁Z -1025 -▁pay -1026 -pping -1027 -ours -1028 -▁With -1029 -▁result -1030 -▁bro -1031 -▁today -1032 -▁head -1033 -▁special -1034 -gy -1035 -▁— -1036 -▁sl -1037 -ps -1038 -▁ty -1039 -▁ve -1040 -ploy -1041 -ER -1042 -▁At -1043 -joy -1044 -▁stand -1045 -ms -1046 -work -1047 -ared -1048 -outh -1049 -▁another -1050 -▁ide -1051 -▁give -1052 -br -1053 -▁ann -1054 -▁Con -1055 -▁wom -1056 -▁provide -1057 -uck -1058 -▁got -1059 -▁cor -1060 -ccess -1061 -ior -1062 -▁Chr -1063 -ote -1064 -oor -1065 -▁Res -1066 -oney -1067 -▁meet -1068 -▁students -1069 -▁resp -1070 -istr -1071 -▁current -1072 -ense -1073 -ately -1074 -▁wr -1075 -▁without -1076 -ision -1077 -▁conf -1078 -▁Our -1079 -ients -1080 -rence -1081 -ok -1082 -ium -1083 -▁old -1084 -▁area -1085 -ley -1086 -ope -1087 -ards -1088 -▁number -1089 -▁four -1090 -▁bre -1091 -▁cost -1092 -aj -1093 -ems -1094 -ered -1095 -▁able -1096 -ically -1097 -▁soc -1098 -▁val -1099 -▁Sp -1100 -▁invest -1101 -▁must -1102 -con -1103 -▁access -1104 -▁services -1105 -▁unt -1106 -raph -1107 -ats -1108 -ird -1109 -▁ask -1110 -▁working -1111 -▁never -1112 -▁US -1113 -▁Cent -1114 -iver -1115 -▁No -1116 -stand -1117 -ww -1118 -▁webs -1119 -▁proble -1120 -▁public -1121 -▁vide -1122 -ission -1123 -▁visit -1124 -▁important -1125 -ann -1126 -▁light -1127 -pped -1128 -▁fact -1129 -let -1130 -▁sal -1131 -▁level -1132 -▁order -1133 -▁fac -1134 -ged -1135 -▁Comm -1136 -▁My -1137 -▁test -1138 -▁might -1139 -▁exc -1140 -ral -1141 -▁rese -1142 -▁product -1143 -▁local -1144 -▁night -1145 -▁season -1146 -inal -1147 -▁el -1148 -▁incre -1149 -ember -1150 -▁site -1151 -rol -1152 -▁That -1153 -▁sing -1154 -ruct -1155 -ample -1156 -▁expl -1157 -▁Mar -1158 -▁spec -1159 -▁grow -1160 -▁let -1161 -▁ca -1162 -▁proper -1163 -▁less -1164 -ording -1165 -▁enjoy -1166 -▁ob -1167 -▁past -1168 -▁event -1169 -▁products -1170 -▁Man -1171 -▁' -1172 -▁inf -1173 -▁May -1174 -▁looking -1175 -▁food -1176 -here -1177 -lection -1178 -▁within -1179 -▁profess -1180 -▁Fe -1181 -▁Is -1182 -▁data -1183 -▁making -1184 -▁pop -1185 -ertain -1186 -▁until -1187 -ases -1188 -ories -1189 -ffic -1190 -enn -1191 -ency -1192 -▁children -1193 -ently -1194 -▁University -1195 -We -1196 -gin -1197 -sh -1198 -▁job -1199 -▁offer -1200 -▁law -1201 -ery -1202 -ains -1203 -ney -1204 -urs -1205 -▁pos -1206 -eng -1207 -utes -1208 -▁power -1209 -▁view -1210 -▁turn -1211 -▁eng -1212 -▁email -1213 -ential -1214 -tend -1215 -▁oper -1216 -▁sit -1217 -▁check -1218 -▁against -1219 -ieve -1220 -▁est -1221 -▁Pr -1222 -ream -1223 -ised -1224 -▁Br -1225 -ina -1226 -▁prote -1227 -ids -1228 -ode -1229 -▁room -1230 -▁contact -1231 -IN -1232 -▁community -1233 -med -1234 -to -1235 -▁addition -1236 -▁prom -1237 -▁says -1238 -▁intern -1239 -load -1240 -▁toget -1241 -▁together -1242 -▁Fl -1243 -▁away -1244 -ivid -1245 -▁impro -1246 -▁quality -1247 -▁leg -1248 -ator -1249 -▁dist -1250 -▁creat -1251 -ills -1252 -irl -1253 -hor -1254 -▁indust -1255 -▁complete -1256 -▁news -1257 -aring -1258 -iron -1259 -ique -1260 -ret -1261 -▁App -1262 -icle -1263 -iday -1264 -agement -1265 -ified -1266 -oci -1267 -▁supp -1268 -osed -1269 -ability -1270 -▁project -1271 -▁website -1272 -▁Car -1273 -iety -1274 -ane -1275 -por -1276 -!! -1277 -▁change -1278 -co -1279 -▁success -1280 -▁dep -1281 -bo -1282 -▁learn -1283 -▁include -1284 -▁Co -1285 -pend -1286 -▁fav -1287 -▁chang -1288 -ym -1289 -▁Ste -1290 -▁detail -1291 -ism -1292 -▁offic -1293 -▁Can -1294 -▁members -1295 -▁dr -1296 -arent -1297 -son -1298 -▁buy -1299 -▁easy -1300 -▁please -1301 -rap -1302 -▁Me -1303 -aster -1304 -▁applic -1305 -ising -1306 -ury -1307 -▁name -1308 -▁pract -1309 -▁times -1310 -atures -1311 -▁along -1312 -▁equ -1313 -▁present -1314 -▁One -1315 -▁large -1316 -▁money -1317 -▁beaut -1318 -atter -1319 -augh -1320 -▁Am -1321 -aterial -1322 -the -1323 -▁Cont -1324 -iting -1325 -▁activ -1326 -vern -1327 -RE -1328 -▁employ -1329 -▁la -1330 -aff -1331 -une -1332 -▁house -1333 -ready -1334 -Th -1335 -▁course -1336 -▁expect -1337 -▁. -1338 -▁needs -1339 -ored -1340 -▁air -1341 -▁left -1342 -▁Christ -1343 -▁thing -1344 -itions -1345 -ift -1346 -sc -1347 -ably -1348 -▁cap -1349 -ider -1350 -ived -1351 -lish -1352 -▁music -1353 -▁dra -1354 -min -1355 -▁why -1356 -▁En -1357 -yle -1358 -ohn -1359 -ump -1360 -ify -1361 -▁hist -1362 -ec -1363 -ron -1364 -by -1365 -▁bas -1366 -ern -1367 -▁hum -1368 -▁video -1369 -rie -1370 -▁sw -1371 -▁account -1372 -ON -1373 -ffe -1374 -alf -1375 -ocus -1376 -veral -1377 -▁below -1378 -▁soft -1379 -▁hot -1380 -▁These -1381 -▁short -1382 -ries -1383 -▁Eng -1384 -▁line -1385 -▁live -1386 -pecial -1387 -▁opport -1388 -enef -1389 -▁create -1390 -book -1391 -▁cond -1392 -▁beh -1393 -▁... -1394 -▁perfect -1395 -uly -1396 -▁ce -1397 -▁page -1398 -▁word -1399 -▁/ -1400 -▁writ -1401 -AT -1402 -▁dem -1403 -ots -1404 -▁Med -1405 -▁mar -1406 -▁Please -1407 -fort -1408 -side -1409 -ows -1410 -mber -1411 -▁govern -1412 -▁pa -1413 -artment -1414 -▁already -1415 -▁Che -1416 -▁kind -1417 -▁After -1418 -▁enough -1419 -▁ever -1420 -▁research -1421 -ured -1422 -▁makes -1423 -▁following -1424 -▁million -1425 -▁Do -1426 -▁review -1427 -▁getting -1428 -▁dev -1429 -ten -1430 -itive -1431 -ush -1432 -▁friends -1433 -▁cut -1434 -▁conne -1435 -▁trad -1436 -ee -1437 -., -1438 -▁record -1439 -room -1440 -▁treat -1441 -▁side -1442 -▁const -1443 -vious -1444 -▁Ass -1445 -▁case -1446 -▁having -1447 -ajor -1448 -▁tell -1449 -▁Count -1450 -▁personal -1451 -▁move -1452 -▁based -1453 -▁story -1454 -viron -1455 -ention -1456 -▁John -1457 -rop -1458 -▁Your -1459 -▁Serv -1460 -▁won -1461 -unch -1462 -ips -1463 -▁Des -1464 -▁minutes -1465 -uper -1466 -▁become -1467 -uture -1468 -▁possible -1469 -osp -1470 -oice -1471 -iam -1472 -▁talk -1473 -▁city -1474 -ights -1475 -▁across -1476 -▁vers -1477 -▁share -1478 -ization -1479 -▁done -1480 -▁bit -1481 -▁camp -1482 -▁pack -1483 -▁didn -1484 -▁comes -1485 -▁men -1486 -▁understand -1487 -ead -1488 -▁several -1489 -▁-- -1490 -yn -1491 -▁: -1492 -▁country -1493 -▁Tw -1494 -▁hours -1495 -▁effect -1496 -▁cou -1497 -▁purch -1498 -iven -1499 -▁benef -1500 -ES -1501 -▁mil -1502 -▁women -1503 -uff -1504 -▁net -1505 -ividual -1506 -app -1507 -aces -1508 -▁percent -1509 -▁Comp -1510 -▁educ -1511 -wards -1512 -▁focus -1513 -▁often -1514 -▁material -1515 -ball -1516 -▁social -1517 -aim -1518 -▁elect -1519 -▁Wor -1520 -idd -1521 -ances -1522 -ination -1523 -uro -1524 -ides -1525 -ober -1526 -▁quick -1527 -▁Not -1528 -▁development -1529 -▁es -1530 -▁bring -1531 -▁return -1532 -orts -1533 -▁American -1534 -ister -1535 -ienc -1536 -▁doing -1537 -▁Bro -1538 -▁School -1539 -ript -1540 -▁pie -1541 -▁X -1542 -▁far -1543 -▁hold -1544 -arl -1545 -▁mult -1546 -ted -1547 -▁body -1548 -arr -1549 -err -1550 -▁Gr -1551 -of -1552 -mend -1553 -▁pot -1554 -ference -1555 -iful -1556 -ones -1557 -AN -1558 -▁wa -1559 -ners -1560 -▁fund -1561 -▁took -1562 -ograph -1563 -▁Here -1564 -▁tre -1565 -ource -1566 -lished -1567 -▁blog -1568 -oose -1569 -itc -1570 -AR -1571 -▁State -1572 -▁doesn -1573 -reet -1574 -conom -1575 -▁jo -1576 -vironment -1577 -▁deal -1578 -lement -1579 -▁others -1580 -▁City -1581 -▁Rep -1582 -▁came -1583 -▁called -1584 -▁started -1585 -▁sum -1586 -▁rele -1587 -org -1588 -▁Inst -1589 -nder -1590 -▁least -1591 -▁months -1592 -▁Intern -1593 -▁space -1594 -acy -1595 -▁Gu -1596 -▁mom -1597 -▁future -1598 -▁orig -1599 -▁compet -1600 -▁individual -1601 -oon -1602 -lege -1603 -▁went -1604 -▁occ -1605 -▁yet -1606 -▁young -1607 -rodu -1608 -▁clean -1609 -▁non -1610 -▁mind -1611 -▁told -1612 -ai -1613 -▁five -1614 -▁early -1615 -▁series -1616 -▁control -1617 -af -1618 -utions -1619 -▁term -1620 -▁major -1621 -oll -1622 -hers -1623 -ille -1624 -ape -1625 -▁games -1626 -ained -1627 -▁comb -1628 -▁means -1629 -▁pict -1630 -▁industry -1631 -▁chall -1632 -yl -1633 -▁tool -1634 -anks -1635 -▁Min -1636 -▁ens -1637 -▁lim -1638 -▁cover -1639 -ctor -1640 -▁fore -1641 -▁ago -1642 -AS -1643 -▁low -1644 -sw -1645 -▁key -1646 -fer -1647 -ama -1648 -▁x -1649 -▁heart -1650 -▁features -1651 -▁Ed -1652 -ilt -1653 -▁tem -1654 -rew -1655 -▁price -1656 -unic -1657 -▁store -1658 -fact -1659 -jects -1660 -▁offers -1661 -▁Ab -1662 -itor -1663 -back -1664 -▁once -1665 -▁specific -1666 -come -1667 -▁range -1668 -▁thought -1669 -ges -1670 -urity -1671 -ither -1672 -ateg -1673 -▁Bo -1674 -▁Jan -1675 -sel -1676 -▁pick -1677 -illed -1678 -▁Now -1679 -eral -1680 -▁God -1681 -▁Dr -1682 -▁favor -1683 -▁appear -1684 -year -1685 -▁More -1686 -▁York -1687 -ilities -1688 -▁Ke -1689 -▁Im -1690 -▁hope -1691 -▁redu -1692 -▁discuss -1693 -OR -1694 -ibr -1695 -▁happen -1696 -▁require -1697 -yr -1698 -▁Pe -1699 -▁However -1700 -atic -1701 -It -1702 -▁mean -1703 -▁single -1704 -nes -1705 -▁step -1706 -▁close -1707 -▁upd -1708 -▁land -1709 -▁break -1710 -▁ey -1711 -▁main -1712 -▁invol -1713 -most -1714 -anies -1715 -▁Pres -1716 -ourn -1717 -▁stay -1718 -▁government -1719 -▁Em -1720 -isk -1721 -isc -1722 -// -1723 -▁Sm -1724 -ony -1725 -▁field -1726 -de -1727 -▁priv -1728 -▁United -1729 -▁beautiful -1730 -resh -1731 -cle -1732 -▁Per -1733 -▁friend -1734 -▁everything -1735 -▁Qu -1736 -▁walk -1737 -ched -1738 -▁questions -1739 -▁added -1740 -▁hig -1741 -▁Cal -1742 -▁tax -1743 -aken -1744 -▁customers -1745 -▁strong -1746 -now -1747 -▁taking -1748 -▁install -1749 -for -1750 -:// -1751 -aps -1752 -ging -1753 -▁Pol -1754 -▁charact -1755 -▁wond -1756 -▁South -1757 -▁begin -1758 -▁study -1759 -ources -1760 -▁North -1761 -▁Just -1762 -▁announ -1763 -ief -1764 -ensive -1765 -▁miss -1766 -▁recom -1767 -▁travel -1768 -▁certain -1769 -▁Park -1770 -▁address -1771 -▁problem -1772 -▁By -1773 -▁County -1774 -▁actually -1775 -play -1776 -▁staff -1777 -▁tot -1778 -▁half -1779 -▁mess -1780 -▁z -1781 -aur -1782 -ew -1783 -inc -1784 -ians -1785 -▁search -1786 -▁technology -1787 -▁girl -1788 -▁media -1789 -urther -1790 -time -1791 -▁watch -1792 -▁typ -1793 -▁known -1794 -▁official -1795 -▁manag -1796 -▁National -1797 -▁six -1798 -irm -1799 -▁Pre -1800 -▁wind -1801 -▁enc -1802 -gle -1803 -atural -1804 -ural -1805 -▁front -1806 -ublic -1807 -▁Add -1808 -▁sound -1809 -▁improve -1810 -▁Post -1811 -wh -1812 -▁dig -1813 -irt -1814 -▁lat -1815 -▁content -1816 -▁Su -1817 -▁Stud -1818 -▁anal -1819 -▁track -1820 -itted -1821 -▁Mc -1822 -▁face -1823 -▁training -1824 -▁link -1825 -▁click -1826 -icy -1827 -▁ste -1828 -▁web -1829 -▁someone -1830 -ison -1831 -▁Oct -1832 -arning -1833 -▁works -1834 -▁author -1835 -▁later -1836 -▁building -1837 -not -1838 -lebr -1839 -▁host -1840 -ocu -1841 -▁Gl -1842 -▁environment -1843 -abor -1844 -cted -1845 -▁Center -1846 -▁mor -1847 -▁log -1848 -▁unique -1849 -▁everyone -1850 -▁Reg -1851 -raft -1852 -▁port -1853 -▁provides -1854 -IS -1855 -gest -1856 -▁ener -1857 -▁fall -1858 -▁cred -1859 -▁seen -1860 -▁Dep -1861 -▁film -1862 -ask -1863 -▁Day -1864 -▁prep -1865 -▁oil -1866 -▁particular -1867 -▁professional -1868 -▁aud -1869 -fully -1870 -▁Aug -1871 -▁Euro -1872 -ests -1873 -▁particip -1874 -lex -1875 -ided -1876 -unities -1877 -▁bar -1878 -ibility -1879 -▁results -1880 -▁ident -1881 -▁recommend -1882 -roll -1883 -▁press -1884 -ED -1885 -▁card -1886 -▁While -1887 -▁Will -1888 -▁whole -1889 -▁Don -1890 -aturday -1891 -▁World -1892 -rain -1893 -▁companies -1894 -ino -1895 -▁Ge -1896 -▁High -1897 -urch -1898 -▁Friday -1899 -▁office -1900 -IT -1901 -pper -1902 -▁Bar -1903 -▁March -1904 -▁color -1905 -▁events -1906 -▁anything -1907 -▁issues -1908 -EN -1909 -ancial -1910 -▁mot -1911 -▁eff -1912 -▁prob -1913 -▁mag -1914 -▁areas -1915 -▁pret -1916 -resent -1917 -▁vol -1918 -▁Some -1919 -▁comput -1920 -▁respons -1921 -ops -1922 -▁points -1923 -▁Acc -1924 -▁performance -1925 -▁near -1926 -▁pain -1927 -ster -1928 -obile -1929 -▁red -1930 -▁print -1931 -▁cook -1932 -▁Apr -1933 -itch -1934 -umb -1935 -▁given -1936 -▁history -1937 -▁econom -1938 -pecially -1939 -crib -1940 -obal -1941 -.... -1942 -▁feature -1943 -go -1944 -ili -1945 -ands -1946 -▁sell -1947 -▁designed -1948 -▁above -1949 -ches -1950 -▁maint -1951 -▁skin -1952 -▁text -1953 -▁aff -1954 -▁simple -1955 -eth -1956 -▁assist -1957 -IC -1958 -my -1959 -ued -1960 -▁age -1961 -icult -1962 -▁reason -1963 -inks -1964 -In -1965 -▁size -1966 -▁question -1967 -▁dou -1968 -imate -1969 -▁according -1970 -▁repl -1971 -iod -1972 -ply -1973 -▁Sec -1974 -nding -1975 -▁black -1976 -▁Aust -1977 -head -1978 -▁htt -1979 -edd -1980 -▁pretty -1981 -▁foot -1982 -▁believe -1983 -▁Saturday -1984 -oved -1985 -ables -1986 -▁due -1987 -▁Part -1988 -▁among -1989 -▁select -1990 -AL -1991 -itter -1992 -▁Sund -1993 -▁fire -1994 -cript -1995 -▁phys -1996 -omes -1997 -ental -1998 -ledge -1999 -▁idea -2000 -ety -2001 -▁latest -2002 -▁details -2003 -▁ant -2004 -▁popular -2005 -ole -2006 -▁third -2007 -▁et -2008 -ators -2009 -▁Mr -2010 -pro -2011 -val -2012 -▁management -2013 -aining -2014 -itional -2015 -▁includes -2016 -ruction -2017 -asing -2018 -▁July -2019 -▁energy -2020 -▁items -2021 -ze -2022 -▁weeks -2023 -ouch -2024 -onday -2025 -▁sent -2026 -▁Feb -2027 -▁living -2028 -ites -2029 -▁cult -2030 -▁receive -2031 -▁fre -2032 -▁continue -2033 -▁bad -2034 -▁June -2035 -▁relations -2036 -▁Europe -2037 -vert -2038 -astic -2039 -idence -2040 -▁human -2041 -▁parent -2042 -ulation -2043 -▁Val -2044 -▁His -2045 -▁claim -2046 -aily -2047 -▁Sept -2048 -ufact -2049 -ctions -2050 -elt -2051 -▁Dav -2052 -▁sex -2053 -▁prop -2054 -▁soon -2055 -ung -2056 -▁property -2057 -▁hon -2058 -nov -2059 -▁currently -2060 -▁amount -2061 -▁entire -2062 -new -2063 -▁West -2064 -uation -2065 -▁coming -2066 -ese -2067 -though -2068 -ana -2069 -ogn -2070 -▁Off -2071 -▁kids -2072 -▁TH -2073 -▁Tra -2074 -▁From -2075 -itting -2076 -▁phone -2077 -This -2078 -cast -2079 -▁final -2080 -▁consum -2081 -▁ess -2082 -▁happy -2083 -▁taken -2084 -▁celebr -2085 -▁docu -2086 -▁member -2087 -icro -2088 -.) -2089 -▁answ -2090 -▁meas -2091 -AC -2092 -▁wanted -2093 -▁type -2094 -▁software -2095 -selves -2096 -▁experienc -2097 -▁forward -2098 -▁diff -2099 -eds -2100 -▁whether -2101 -▁Us -2102 -▁wide -2103 -▁Read -2104 -▁either -2105 -▁Bu -2106 -ires -2107 -▁El -2108 -▁value -2109 -▁concer -2110 -▁deb -2111 -▁further -2112 -ux -2113 -ilar -2114 -ival -2115 -▁isn -2116 -▁coll -2117 -used -2118 -ams -2119 -aced -2120 -▁par -2121 -▁almost -2122 -▁required -2123 -▁crit -2124 -▁held -2125 -▁white -2126 -arter -2127 -▁date -2128 -▁comfort -2129 -▁quite -2130 -▁trying -2131 -▁provided -2132 -▁summer -2133 -▁Sw -2134 -▁fit -2135 -▁Pa -2136 -▁sugg -2137 -▁needed -2138 -▁favorite -2139 -▁tit -2140 -St -2141 -ees -2142 -▁Sunday -2143 -▁opportunity -2144 -▁Jo -2145 -▁ach -2146 -aching -2147 -uary -2148 -ek -2149 -▁Cor -2150 -▁via -2151 -▁extra -2152 -▁players -2153 -▁April -2154 -▁books -2155 -▁Monday -2156 -▁network -2157 -▁cop -2158 -amer -2159 -ler -2160 -▁example -2161 -▁box -2162 -▁users -2163 -▁, -2164 -itten -2165 -▁seem -2166 -▁period -2167 -▁various -2168 -▁Health -2169 -▁options -2170 -where -2171 -▁running -2172 -gress -2173 -▁style -2174 -▁especially -2175 -▁consider -2176 -▁yourself -2177 -▁Art -2178 -▁dam -2179 -▁safe -2180 -▁previous -2181 -▁swe -2182 -▁ways -2183 -▁version -2184 -▁created -2185 -▁sle -2186 -▁Mon -2187 -▁recently -2188 -▁potential -2189 -OU -2190 -▁issue -2191 -▁common -2192 -ises -2193 -▁di -2194 -▁Inc -2195 -▁stri -2196 -▁ready -2197 -▁attend -2198 -▁morning -2199 -▁regular -2200 -▁insp -2201 -▁else -2202 -▁road -2203 -▁nice -2204 -▁throughout -2205 -▁probably -2206 -▁ensure -2207 --- -2208 -▁veh -2209 -▁received -2210 -earch -2211 -▁ball -2212 -▁Associ -2213 -▁President -2214 -▁clear -2215 -▁download -2216 -par -2217 -icles -2218 -▁engine -2219 -▁sho -2220 -erc -2221 -▁song -2222 -azing -2223 -▁lo -2224 -▁brand -2225 -▁relationship -2226 -▁takes -2227 -▁reading -2228 -mit -2229 -▁natural -2230 -▁Aut -2231 -▁States -2232 -ades -2233 -amed -2234 -▁park -2235 -▁House -2236 -ively -2237 -▁shows -2238 -▁asked -2239 -▁medical -2240 -istration -2241 -ague -2242 -▁inj -2243 -▁hit -2244 -▁choose -2245 -▁collect -2246 -▁Direct -2247 -▁Mich -2248 -▁original -2249 -▁cool -2250 -▁spr -2251 -▁couple -2252 -angu -2253 -reme -2254 -ipping -2255 -▁represent -2256 -▁bott -2257 -▁init -2258 -▁release -2259 -▁goal -2260 -▁behind -2261 -ny -2262 -apt -2263 -oid -2264 -▁Face -2265 -▁wonder -2266 -▁Soc -2267 -▁recent -2268 -▁sales -2269 -eter -2270 -▁clients -2271 -▁financial -2272 -aging -2273 -overed -2274 -▁accom -2275 -▁fresh -2276 -▁fast -2277 -▁super -2278 -▁leave -2279 -▁problems -2280 -▁anyone -2281 -▁role -2282 -face -2283 -▁Get -2284 -gs -2285 -hib -2286 -▁Ser -2287 -▁career -2288 -uge -2289 -▁Fin -2290 -bor -2291 -▁Black -2292 -ume -2293 -▁cup -2294 -ried -2295 -ville -2296 -▁model -2297 -▁article -2298 -oura -2299 -▁ful -2300 -uesday -2301 -▁meth -2302 -arth -2303 -▁ground -2304 -▁programs -2305 -▁Up -2306 -▁hol -2307 -▁fail -2308 -na -2309 -▁sun -2310 -aving -2311 -▁weeke -2312 -▁accept -2313 -▁flow -2314 -ada -2315 -ursday -2316 -▁base -2317 -medi -2318 -▁customer -2319 -▁difficult -2320 -OT -2321 -atform -2322 -▁writing -2323 -anced -2324 -urance -2325 -▁looks -2326 -▁PM -2327 -▁tour -2328 -▁polit -2329 -▁likely -2330 -ox -2331 -hel -2332 -oogle -2333 -▁paper -2334 -▁ap -2335 -▁abs -2336 -▁simply -2337 -cing -2338 -name -2339 -verage -2340 -▁inside -2341 -▁manufact -2342 -▁TV -2343 -clus -2344 -▁etc -2345 -▁mix -2346 -▁total -2347 -▁included -2348 -▁po -2349 -idge -2350 -ming -2351 -▁Int -2352 -▁risk -2353 -▁Wed -2354 -adem -2355 -aker -2356 -▁increase -2357 -▁party -2358 -▁changes -2359 -▁ele -2360 -ashing -2361 -▁board -2362 -▁education -2363 -oud -2364 -▁Her -2365 -▁October -2366 -▁action -2367 -▁former -2368 -▁meeting -2369 -Wh -2370 -▁however -2371 -▁News -2372 -▁outside -2373 -ification -2374 -uit -2375 -iple -2376 -▁match -2377 -▁Ac -2378 -▁America -2379 -▁Act -2380 -▁nothing -2381 -▁security -2382 -▁self -2383 -ground -2384 -▁contrib -2385 -▁stop -2386 -ester -2387 -▁town -2388 -▁August -2389 -▁matter -2390 -▁position -2391 -▁Af -2392 -▁ple -2393 -▁bed -2394 -▁late -2395 -istrict -2396 -▁Ob -2397 -▁systems -2398 -▁Every -2399 -icated -2400 -adu -2401 -ules -2402 -▁Bus -2403 -▁words -2404 -▁playing -2405 -▁cir -2406 -▁pan -2407 -ST -2408 -▁UK -2409 -wood -2410 -▁sat -2411 -▁impact -2412 -▁anim -2413 -▁mark -2414 -▁private -2415 -▁application -2416 -▁police -2417 -▁knowledge -2418 -▁exist -2419 -▁photos -2420 -▁method -2421 -▁longer -2422 -▁coun -2423 -▁worked -2424 -iddle -2425 -▁national -2426 -▁projects -2427 -ederal -2428 -▁ord -2429 -▁Are -2430 -▁necess -2431 -ude -2432 -▁table -2433 -▁stra -2434 -off -2435 -▁Ag -2436 -empt -2437 -elcome -2438 -▁September -2439 -ecut -2440 -▁activities -2441 -▁worth -2442 -▁recogn -2443 -▁production -2444 -str -2445 -nesday -2446 -▁Department -2447 -based -2448 -aby -2449 -iff -2450 -▁comment -2451 -▁compl -2452 -▁skills -2453 -▁true -2454 -▁general -2455 -▁Austral -2456 -▁January -2457 -iol -2458 -▁round -2459 -▁lives -2460 -▁learning -2461 -▁Tuesday -2462 -▁Thursday -2463 -ID -2464 -che -2465 -▁Then -2466 -▁introdu -2467 -ky -2468 -arden -2469 -▁signific -2470 -ING -2471 -oom -2472 -▁Sal -2473 -▁ill -2474 -▁student -2475 -▁Pat -2476 -▁lay -2477 -▁hair -2478 -▁Free -2479 -▁Nove -2480 -▁computer -2481 -▁squ -2482 -▁purchase -2483 -▁tal -2484 -ham -2485 -▁Also -2486 -ession -2487 -ett -2488 -▁Mus -2489 -▁death -2490 -▁defin -2491 -▁seems -2492 -▁Of -2493 -ci -2494 -▁hands -2495 -izing -2496 -▁communic -2497 -mon -2498 -▁rad -2499 -▁choice -2500 -▁screen -2501 -AM -2502 -▁draw -2503 -▁concern -2504 -▁leading -2505 -▁additional -2506 -▁First -2507 -▁rights -2508 -attle -2509 -▁cell -2510 -▁credit -2511 -▁located -2512 -▁variety -2513 -▁leaders -2514 -▁Facebook -2515 -▁stat -2516 -▁tick -2517 -▁drive -2518 -▁movie -2519 -▁San -2520 -arget -2521 -oring -2522 -▁file -2523 -▁fig -2524 -ipment -2525 -▁hy -2526 -▁bud -2527 -▁image -2528 -▁determ -2529 -▁amazing -2530 -aign -2531 -▁Sim -2532 -▁suggest -2533 -mercial -2534 -▁chance -2535 -▁Red -2536 -▁associ -2537 -▁rather -2538 -▁practice -2539 -▁built -2540 -▁plans -2541 -▁function -2542 -oph -2543 -▁Har -2544 -▁providing -2545 -iter -2546 -▁cal -2547 -ached -2548 -airs -2549 -light -2550 -ought -2551 -urg -2552 -pm -2553 -▁War -2554 -▁vict -2555 -▁court -2556 -▁aw -2557 -▁saf -2558 -▁cand -2559 -example -2560 -▁Out -2561 -▁touch -2562 -▁Air -2563 -▁teac -2564 -cil -2565 -▁exam -2566 -▁autom -2567 -▁Street -2568 -▁international -2569 -▁loss -2570 -▁weekend -2571 -▁Wind -2572 -▁infl -2573 -▁prior -2574 -▁prevent -2575 -▁allows -2576 -▁arri -2577 -▁Calif -2578 -▁Click -2579 -irth -2580 -ibrary -2581 -▁character -2582 -▁piece -2583 -▁treatment -2584 -cember -2585 -itchen -2586 -olution -2587 -▁http -2588 -ma -2589 -▁similar -2590 -▁Most -2591 -▁moment -2592 -gar -2593 -oke -2594 -ruary -2595 -▁clos -2596 -▁Design -2597 -▁investig -2598 -▁rate -2599 -▁AM -2600 -reg -2601 -▁commit -2602 -▁growth -2603 -imum -2604 -▁norm -2605 -OM -2606 -iber -2607 -▁Dis -2608 -ivery -2609 -▁estab -2610 -▁cause -2611 -▁user -2612 -sp -2613 -▁deg -2614 -▁lost -2615 -▁display -2616 -▁collection -2617 -▁myself -2618 -▁Cr -2619 -▁op -2620 -▁enter -2621 -▁Wednesday -2622 -unt -2623 -▁rout -2624 -ault -2625 -▁decided -2626 -▁decision -2627 -▁sil -2628 -▁inde -2629 -▁Any -2630 -▁higher -2631 -cy -2632 -▁bal -2633 -▁daily -2634 -ha -2635 -ournal -2636 -▁digital -2637 -▁November -2638 -▁purp -2639 -▁Group -2640 -▁released -2641 -▁significant -2642 -▁reported -2643 -LE -2644 -▁Home -2645 -▁woman -2646 -▁Cour -2647 -▁easily -2648 -▁cannot -2649 -▁goes -2650 -▁International -2651 -▁excell -2652 -lin -2653 -▁wall -2654 -▁Thanks -2655 -▁quickly -2656 -▁College -2657 -▁usually -2658 -amb -2659 -▁bag -2660 -▁apply -2661 -▁floor -2662 -▁expected -2663 -iant -2664 -▁involved -2665 -▁Law -2666 -▁dom -2667 -▁attack -2668 -just -2669 -▁boy -2670 -illing -2671 -▁regard -2672 -▁platform -2673 -▁capt -2674 -▁iP -2675 -▁Net -2676 -▁encoura -2677 -▁protect -2678 -ondon -2679 -▁Cons -2680 -▁agree -2681 -ael -2682 -▁serious -2683 -▁December -2684 -▁safety -2685 -▁roll -2686 -▁saw -2687 -▁dress -2688 -▁Google -2689 -▁gen -2690 -▁parents -2691 -▁mach -2692 -idents -2693 -▁played -2694 -▁Service -2695 -▁immedi -2696 -▁surpr -2697 -mas -2698 -▁warm -2699 -zz -2700 -▁integr -2701 -▁mobile -2702 -▁tast -2703 -ica -2704 -▁February -2705 -▁sn -2706 -▁club -2707 -▁langu -2708 -▁president -2709 -▁sche -2710 -▁related -2711 -hern -2712 -▁shoot -2713 -▁finish -2714 -▁ideas -2715 -▁global -2716 -▁marketing -2717 -▁tools -2718 -▁ep -2719 -▁expert -2720 -band -2721 -▁code -2722 -▁exact -2723 -ospital -2724 -asons -2725 -▁mass -2726 -▁note -2727 -avy -2728 -▁photo -2729 -izes -2730 -▁save -2731 -▁source -2732 -▁ut -2733 -▁option -2734 -▁respect -2735 -▁Brit -2736 -▁Let -2737 -▁feed -2738 -enge -2739 -iding -2740 -▁arch -2741 -▁deep -2742 -▁corre -2743 -▁Ang -2744 -▁announced -2745 -ilies -2746 -▁appe -2747 -edding -2748 -▁Well -2749 -cription -2750 -▁La -2751 -www -2752 -hood -2753 -reng -2754 -▁stock -2755 -▁sens -2756 -▁admin -2757 -▁location -2758 -▁ri -2759 -ellow -2760 -▁gets -2761 -▁David -2762 -▁costs -2763 -▁helps -2764 -▁Av -2765 -ples -2766 -▁materials -2767 -ength -2768 -▁Je -2769 -ipe -2770 -rab -2771 -▁Tex -2772 -▁huge -2773 -▁published -2774 -agn -2775 -like -2776 -AP -2777 -▁send -2778 -▁mother -2779 -▁benefits -2780 -▁English -2781 -enior -2782 -mission -2783 -ography -2784 -▁lab -2785 -oday -2786 -▁Play -2787 -▁fight -2788 -▁Over -2789 -▁hear -2790 -▁weight -2791 -rown -2792 -▁Spr -2793 -ornia -2794 -uel -2795 -vey -2796 -iction -2797 -▁images -2798 -rought -2799 -▁restaur -2800 -key -2801 -▁gar -2802 -▁Book -2803 -▁earn -2804 -ald -2805 -▁ability -2806 -▁interview -2807 -add -2808 -▁Check -2809 -▁Business -2810 -atory -2811 -▁London -2812 -ructure -2813 -▁written -2814 -akers -2815 -▁challeng -2816 -▁standard -2817 -▁gives -2818 -▁giving -2819 -▁ones -2820 -▁legal -2821 -▁sense -2822 -▁campaign -2823 -▁Sch -2824 -▁dest -2825 -▁innov -2826 -erved -2827 -▁door -2828 -▁patients -2829 -rom -2830 -▁mid -2831 -▁trust -2832 -urt -2833 -▁sus -2834 -▁wasn -2835 -▁Services -2836 -▁center -2837 -▁instead -2838 -aged -2839 -▁Produ -2840 -▁fab -2841 -▁Coun -2842 -▁heat -2843 -▁neg -2844 -▁fine -2845 -▁item -2846 -▁Great -2847 -▁target -2848 -erous -2849 -▁prem -2850 -erve -2851 -▁sold -2852 -▁White -2853 -aught -2854 -▁wish -2855 -▁Trans -2856 -▁parts -2857 -▁write -2858 -▁levels -2859 -▁lic -2860 -▁award -2861 -iring -2862 -arant -2863 -aves -2864 -▁cases -2865 -▁describ -2866 -▁picture -2867 -▁pers -2868 -▁partners -2869 -▁Web -2870 -▁dry -2871 -▁neigh -2872 -irit -2873 -▁Mod -2874 -▁Prof -2875 -▁stuff -2876 -ashington -2877 -ida -2878 -▁pull -2879 -▁conditions -2880 -▁ded -2881 -atives -2882 -▁green -2883 -▁California -2884 -▁broad -2885 -▁effic -2886 -▁Hol -2887 -board -2888 -▁Hall -2889 -put -2890 -rows -2891 -▁Program -2892 -ivity -2893 -▁began -2894 -▁sale -2895 -▁upon -2896 -istic -2897 -▁highly -2898 -▁interesting -2899 -TM -2900 -bit -2901 -OS -2902 -▁vot -2903 -▁fans -2904 -▁stories -2905 -inner -2906 -▁request -2907 -▁contract -2908 -▁remember -2909 -▁slow -2910 -▁Cle -2911 -▁emer -2912 -▁subs -2913 -▁answer -2914 -▁Techn -2915 -anch -2916 -▁comments -2917 -acing -2918 -ocol -2919 -▁bra -2920 -▁Phot -2921 -▁wood -2922 -▁Other -2923 -▁lower -2924 -▁sym -2925 -▁dead -2926 -orge -2927 -▁prim -2928 -orage -2929 -▁modern -2930 -▁player -2931 -▁cat -2932 -coming -2933 -bum -2934 -▁interested -2935 -ooth -2936 -▁reports -2937 -aches -2938 -▁except -2939 -ara -2940 -lev -2941 -▁dise -2942 -▁trip -2943 -▁teams -2944 -▁Jack -2945 -▁Texas -2946 -▁attention -2947 -▁equipment -2948 -▁paint -2949 -sy -2950 -▁fully -2951 -▁wrong -2952 -▁directly -2953 -▁starting -2954 -▁completely -2955 -▁organization -2956 -▁types -2957 -uk -2958 -wide -2959 -▁Green -2960 -mm -2961 -▁resources -2962 -▁Last -2963 -▁www -2964 -ET -2965 -urb -2966 -ager -2967 -▁document -2968 -▁themselves -2969 -apan -2970 -▁dru -2971 -▁solutions -2972 -▁stru -2973 -▁viol -2974 -ashion -2975 -▁bank -2976 -▁Washington -2977 -▁Loc -2978 -▁Rem -2979 -ament -2980 -▁multiple -2981 -▁Association -2982 -▁band -2983 -▁achieve -2984 -▁condition -2985 -▁gold -2986 -▁businesses -2987 -▁Twitter -2988 -uses -2989 -▁wait -2990 -ule -2991 -▁Go -2992 -ening -2993 -udd -2994 -▁Each -2995 -▁affect -2996 -▁opportunities -2997 -▁vac -2998 -▁Gener -2999 -urer -3000 -▁hop -3001 -EC -3002 -▁sett -3003 -▁policy -3004 -▁Par -3005 -▁led -3006 -ension -3007 -▁thinking -3008 -▁dream -3009 -▁Once -3010 -raz -3011 -rel -3012 -▁groups -3013 -▁planning -3014 -▁commercial -3015 -EO -3016 -He -3017 -ffee -3018 -olf -3019 -▁Spe -3020 -▁separ -3021 -▁applications -3022 -▁qual -3023 -▁streng -3024 -▁approach -3025 -▁families -3026 -▁solution -3027 -▁Del -3028 -▁firm -3029 -▁Class -3030 -▁express -3031 -ores -3032 -▁gave -3033 -▁Found -3034 -enty -3035 -iles -3036 -▁offe -3037 -▁consult -3038 -▁Year -3039 -▁gift -3040 -▁subject -3041 -▁Mem -3042 -AD -3043 -▁Afric -3044 -▁prices -3045 -▁successful -3046 -ties -3047 -▁positive -3048 -▁employees -3049 -arlier -3050 -▁blood -3051 -▁AN -3052 -▁race -3053 -itute -3054 -▁deliver -3055 -oul -3056 -▁join -3057 -ares -3058 -▁itself -3059 -▁King -3060 -▁shot -3061 -▁advice -3062 -▁cert -3063 -▁THE -3064 -▁eye -3065 -riend -3066 -▁hour -3067 -▁defe -3068 -▁saying -3069 -▁healthy -3070 -▁glass -3071 -▁creating -3072 -▁Sub -3073 -▁According -3074 -▁dark -3075 -ration -3076 -▁spent -3077 -▁div -3078 -▁Even -3079 -▁Why -3080 -field -3081 -▁cy -3082 -itely -3083 -ford -3084 -▁Best -3085 -▁cancer -3086 -▁Christmas -3087 -▁effective -3088 -▁serve -3089 -omen -3090 -▁sites -3091 -▁budget -3092 -▁Whe -3093 -▁Road -3094 -▁lif -3095 -▁goals -3096 -▁message -3097 -king -3098 -▁Vis -3099 -▁reve -3100 -mb -3101 -down -3102 -▁Paul -3103 -▁fair -3104 -▁India -3105 -▁average -3106 -▁Dan -3107 -▁fix -3108 -▁circ -3109 -▁Office -3110 -▁Pri -3111 -▁condu -3112 -▁East -3113 -▁reach -3114 -elling -3115 -▁Since -3116 -▁cross -3117 -aughter -3118 -▁traditional -3119 -▁extreme -3120 -▁organiz -3121 -▁director -3122 -PS -3123 -▁Hot -3124 -▁implement -3125 -Ch -3126 -▁sometimes -3127 -▁physical -3128 -▁obs -3129 -ipped -3130 -▁camer -3131 -ords -3132 -vis -3133 -▁Oh -3134 -▁opp -3135 -▁adult -3136 -▁terms -3137 -iable -3138 -▁Germ -3139 -▁plant -3140 -▁wonderful -3141 -US -3142 -rote -3143 -▁hor -3144 -▁Many -3145 -▁Rec -3146 -▁aim -3147 -▁attempt -3148 -▁limited -3149 -▁pictures -3150 -tee -3151 -▁Japan -3152 -▁See -3153 -▁Develop -3154 -▁excellent -3155 -▁dro -3156 -urning -3157 -ysis -3158 -▁mount -3159 -BC -3160 -▁emb -3161 -▁Work -3162 -imately -3163 -onse -3164 -▁brought -3165 -uth -3166 -yond -3167 -▁Ann -3168 -▁quarter -3169 -hest -3170 -▁title -3171 -▁section -3172 -ecutive -3173 -▁block -3174 -▁delivery -3175 -▁Mor -3176 -▁became -3177 -▁farm -3178 -▁arr -3179 -▁carry -3180 -▁effort -3181 -▁IN -3182 -▁kitchen -3183 -▁mention -3184 -▁developed -3185 -▁imm -3186 -inary -3187 -▁Use -3188 -iance -3189 -yright -3190 -reci -3191 -▁jud -3192 -▁fish -3193 -▁China -3194 -▁Inter -3195 -▁countries -3196 -estern -3197 -▁progress -3198 -▁necessary -3199 -▁ge -3200 -▁suppl -3201 -▁sweet -3202 -pendent -3203 -▁complex -3204 -ocks -3205 -▁baby -3206 -vest -3207 -▁felt -3208 -mitted -3209 -▁feeling -3210 -▁System -3211 -▁nation -3212 -▁promot -3213 -▁Top -3214 -▁Make -3215 -▁Dem -3216 -▁Good -3217 -hold -3218 -iced -3219 -▁birth -3220 -▁sleep -3221 -▁growing -3222 -▁impress -3223 -porate -3224 -▁Public -3225 -▁places -3226 -ocr -3227 -▁seven -3228 -▁IT -3229 -▁Flor -3230 -ffects -3231 -venue -3232 -▁Mac -3233 -▁war -3234 -▁heard -3235 -itation -3236 -gu -3237 -pite -3238 -▁weather -3239 -▁Lear -3240 -▁Open -3241 -▁region -3242 -▁Michael -3243 -haps -3244 -▁billion -3245 -▁son -3246 -itary -3247 -▁star -3248 -▁Sur -3249 -duc -3250 -▁Today -3251 -▁hotel -3252 -▁wants -3253 -Re -3254 -▁Thank -3255 -▁stick -3256 -▁college -3257 -▁construction -3258 -IL -3259 -▁bi -3260 -▁album -3261 -▁spend -3262 -▁mat -3263 -▁cold -3264 -▁medic -3265 -▁stage -3266 -▁ver -3267 -▁Port -3268 -▁Director -3269 -▁individuals -3270 -▁double -3271 -nded -3272 -▁Canada -3273 -▁Market -3274 -): -3275 -EL -3276 -aries -3277 -▁Down -3278 -▁convers -3279 -▁Russ -3280 -▁profession -3281 -ying -3282 -▁ble -3283 -▁speed -3284 -▁distrib -3285 -pects -3286 -▁exerc -3287 -rup -3288 -▁ST -3289 -aled -3290 -▁finished -3291 -fl -3292 -▁gas -3293 -istry -3294 -▁suit -3295 -ils -3296 -▁pages -3297 -▁statement -3298 -pre -3299 -ancy -3300 -▁charge -3301 -▁ing -3302 -▁spot -3303 -▁ult -3304 -▁requirements -3305 -▁finally -3306 -▁schools -3307 -▁vehicle -3308 -▁smart -3309 -▁annual -3310 -▁Windows -3311 -". -3312 -ado -3313 -wor -3314 -▁eat -3315 -useum -3316 -▁feet -3317 -▁Board -3318 -▁advant -3319 -ibly -3320 -▁blue -3321 -▁load -3322 -▁aware -3323 -unk -3324 -▁Gold -3325 -▁Research -3326 -▁straight -3327 -▁appl -3328 -arc -3329 -▁Mark -3330 -▁nearly -3331 -ato -3332 -▁Bel -3333 -▁Tom -3334 -▁tried -3335 -▁hous -3336 -▁avoid -3337 -aling -3338 -ports -3339 -▁difference -3340 -▁wrote -3341 -▁William -3342 -▁Sol -3343 -▁pattern -3344 -owl -3345 -ened -3346 -▁James -3347 -▁respond -3348 -▁challenge -3349 -▁Bre -3350 -▁dog -3351 -▁beginning -3352 -ION -3353 -▁Educ -3354 -▁About -3355 -▁helping -3356 -:|| -3357 -▁benefit -3358 -▁insurance -3359 -▁situation -3360 -iment -3361 -▁essential -3362 -▁imag -3363 -ancing -3364 -unte -3365 -▁device -3366 -ceed -3367 -▁Obama -3368 -rast -3369 -▁shop -3370 -ological -3371 -▁Care -3372 -▁Indian -3373 -▁political -3374 -box -3375 -uted -3376 -▁Time -3377 -▁loved -3378 -▁Review -3379 -ube -3380 -▁nut -3381 -▁pow -3382 -overn -3383 -▁wear -3384 -▁Apple -3385 -▁Sl -3386 -▁Mag -3387 -olute -3388 -▁Find -3389 -▁activity -3390 -▁devices -3391 -▁moving -3392 -▁Met -3393 -▁lik -3394 -▁paid -3395 -▁enh -3396 -▁Club -3397 -▁Hel -3398 -▁uses -3399 -▁eight -3400 -▁exhib -3401 -▁Court -3402 -▁turned -3403 -oms -3404 -oses -3405 -▁posted -3406 -▁towards -3407 -”. -3408 -▁nature -3409 -▁Sk -3410 -▁partner -3411 -asy -3412 -▁investment -3413 -ourney -3414 -▁appreci -3415 -▁offering -3416 -▁temper -3417 -▁contain -3418 -▁largest -3419 -ivil -3420 -▁knew -3421 -▁ahead -3422 -oves -3423 -rench -3424 -idered -3425 -▁retail -3426 -▁hus -3427 -▁eyes -3428 -▁owners -3429 -▁language -3430 -▁Ant -3431 -inger -3432 -▁expand -3433 -house -3434 -ey -3435 -rences -3436 -ios -3437 -▁rent -3438 -ned -3439 -▁cas -3440 -▁connect -3441 -▁wife -3442 -ampions -3443 -▁advert -3444 -▁Rel -3445 -▁Rich -3446 -▁reduce -3447 -▁European -3448 -▁guarant -3449 -ago -3450 -cause -3451 -▁Look -3452 -▁sports -3453 -▁correct -3454 -aly -3455 -anta -3456 -▁categ -3457 -▁client -3458 -▁states -3459 -▁consist -3460 -pri -3461 -▁maybe -3462 -▁named -3463 -▁definitely -3464 -hips -3465 -▁influ -3466 -▁entertain -3467 -erry -3468 -hens -3469 -▁accur -3470 -▁concept -3471 -osing -3472 -ounds -3473 -▁runs -3474 -▁grand -3475 -▁stress -3476 -IP -3477 -change -3478 -▁Super -3479 -▁guide -3480 -▁homes -3481 -▁Have -3482 -▁thous -3483 -last -3484 -▁jobs -3485 -▁offered -3486 -estival -3487 -▁earlier -3488 -▁immediately -3489 -▁doll -3490 -▁numbers -3491 -sych -3492 -▁conc -3493 -iers -3494 -▁decl -3495 -▁Fam -3496 -esome -3497 -▁Rob -3498 -▁rates -3499 -▁Council -3500 -azine -3501 -▁rev -3502 -▁Community -3503 -▁path -3504 -▁collabor -3505 -lying -3506 -roud -3507 -▁Cop -3508 -You -3509 -alt -3510 -orrow -3511 -▁candid -3512 -▁interact -3513 -ails -3514 -▁remain -3515 -▁II -3516 -more -3517 -▁bottom -3518 -sec -3519 -dule -3520 -▁Sum -3521 -▁Cong -3522 -▁belie -3523 -▁drink -3524 -▁pieces -3525 -▁exactly -3526 -asc -3527 -lim -3528 -▁tips -3529 -▁Micro -3530 -▁View -3531 -iation -3532 -▁overall -3533 -▁max -3534 -▁federal -3535 -▁storage -3536 -vin -3537 -icious -3538 -▁Custom -3539 -▁opening -3540 -▁demand -3541 -▁Two -3542 -place -3543 -▁surround -3544 -▁Cur -3545 -▁histor -3546 -▁Bay -3547 -orial -3548 -▁Rober -3549 -▁adjust -3550 -ulations -3551 -▁shipping -3552 -▁strateg -3553 -▁Internet -3554 -▁active -3555 -▁threat -3556 -ram -3557 -▁Win -3558 -▁looked -3559 -oma -3560 -▁ten -3561 -▁occas -3562 -▁length -3563 -inated -3564 -▁served -3565 -▁conference -3566 -ico -3567 -iny -3568 -▁IS -3569 -▁guys -3570 -▁rock -3571 -▁button -3572 -▁garden -3573 -▁Florida -3574 -▁acqu -3575 -▁Police -3576 -▁easier -3577 -▁Angel -3578 -yd -3579 -order -3580 -undred -3581 -▁Island -3582 -▁father -3583 -oly -3584 -▁bath -3585 -▁speak -3586 -▁attract -3587 -If -3588 -▁normal -3589 -▁thanks -3590 -dom -3591 -umn -3592 -▁Love -3593 -▁thank -3594 -▁bill -3595 -▁People -3596 -▁background -3597 -illa -3598 -rial -3599 -▁born -3600 -arily -3601 -▁girls -3602 -rig -3603 -▁Ev -3604 -▁Det -3605 -▁wedding -3606 -care -3607 -▁lots -3608 -▁damage -3609 -roid -3610 -▁Big -3611 -▁fat -3612 -▁pet -3613 -bl -3614 -ses -3615 -▁Ty -3616 -▁culture -3617 -▁replace -3618 -▁creative -3619 -▁internet -3620 -▁completed -3621 -▁assess -3622 -OL -3623 -▁Call -3624 -▁prec -3625 -aduate -3626 -atever -3627 -mod -3628 -que -3629 -▁Life -3630 -▁Team -3631 -▁wine -3632 -▁Company -3633 -▁husband -3634 -ij -3635 -▁coach -3636 -▁beyond -3637 -aith -3638 -▁cards -3639 -ipp -3640 -▁cash -3641 -▁Child -3642 -▁haven -3643 -▁altern -3644 -ota -3645 -▁Matt -3646 -▁guy -3647 -phone -3648 -▁depend -3649 -▁setting -3650 -leg -3651 -▁bul -3652 -▁Back -3653 -▁Show -3654 -▁miles -3655 -▁er -3656 -antly -3657 -force -3658 -▁transport -3659 -▁Management -3660 -ustain -3661 -body -3662 -ston -3663 -wise -3664 -▁emot -3665 -▁behav -3666 -▁driving -3667 -▁cream -3668 -▁response -3669 -iling -3670 -▁pred -3671 -▁estate -3672 -ously -3673 -het -3674 -▁USA -3675 -oving -3676 -isions -3677 -▁owner -3678 -▁Australia -3679 -friend -3680 -▁Pet -3681 -▁Sun -3682 -▁cho -3683 -error -3684 -▁Contact -3685 -izz -3686 -▁excited -3687 -▁selection -3688 -▁Ir -3689 -ales -3690 -anging -3691 -▁Ret -3692 -▁middle -3693 -▁efforts -3694 -▁particularly -3695 -▁Plan -3696 -▁Pal -3697 -itect -3698 -icks -3699 -▁Dri -3700 -▁helped -3701 -door -3702 -ustr -3703 -▁Lake -3704 -▁doub -3705 -▁colors -3706 -▁inform -3707 -▁Ve -3708 -aper -3709 -▁files -3710 -▁allowed -3711 -▁lines -3712 -▁existing -3713 -▁Bank -3714 -▁satis -3715 -▁patient -3716 -▁comfortable -3717 -istered -3718 -▁welcome -3719 -▁considered -3720 -▁responsible -3721 -▁clot -3722 -▁drop -3723 -▁truly -3724 -▁coffee -3725 -▁understanding -3726 -DA -3727 -▁plus -3728 -▁Govern -3729 -▁Thom -3730 -▁measure -3731 -set -3732 -▁economic -3733 -▁Yes -3734 -oming -3735 -▁frame -3736 -▁slight -3737 -▁journey -3738 -isl -3739 -▁Dec -3740 -▁indic -3741 -▁degree -3742 -▁ingred -3743 -▁himself -3744 -bon -3745 -▁purpose -3746 -▁tom -3747 -▁surv -3748 -▁changed -3749 -▁liter -3750 -▁mission -3751 -free -3752 -nown -3753 -ences -3754 -onstr -3755 -ona -3756 -▁Although -3757 -EM -3758 -▁pen -3759 -ologies -3760 -▁models -3761 -reed -3762 -▁train -3763 -▁winter -3764 -▁prot -3765 -▁stream -3766 -▁highest -3767 -ads -3768 -see -3769 -encies -3770 -▁prefer -3771 -▁seeing -3772 -▁strugg -3773 -▁evening -3774 -press -3775 -▁Take -3776 -▁artist -3777 -▁talking -3778 -OW -3779 -▁Camp -3780 -▁Phil -3781 -▁afford -3782 -▁Information -3783 -▁Str -3784 -▁sty -3785 -▁Smith -3786 -▁fashion -3787 -▁Republic -3788 -▁gun -3789 -▁disease -3790 -▁pool -3791 -▁absolute -3792 -OV -3793 -▁Sen -3794 -▁shopping -3795 -raw -3796 -oman -3797 -apter -3798 -▁River -3799 -▁Church -3800 -met -3801 -soft -3802 -▁Mart -3803 -▁lack -3804 -▁appoint -3805 -▁heavy -3806 -▁letter -3807 -rem -3808 -▁Color -3809 -▁British -3810 -▁daughter -3811 -▁fem -3812 -▁Rock -3813 -▁cast -3814 -▁brother -3815 -rey -3816 -▁Sing -3817 -▁flav -3818 -porary -3819 -▁occur -3820 -▁smooth -3821 -▁opin -3822 -▁increased -3823 -▁Jes -3824 -▁Music -3825 -▁moved -3826 -▁proud -3827 -▁couldn -3828 -▁launch -3829 -▁analysis -3830 -▁organizations -3831 -dd -3832 -▁PC -3833 -tion -3834 -▁mer -3835 -fit -3836 -▁links -3837 -gery -3838 -▁obt -3839 -▁Water -3840 -▁craft -3841 -▁church -3842 -▁compon -3843 -▁Blue -3844 -▁fill -3845 -▁rules -3846 -▁shared -3847 -▁spring -3848 -eria -3849 -uled -3850 -▁mail -3851 -▁Under -3852 -▁sched -3853 -▁Because -3854 -ronic -3855 -chan -3856 -▁Special -3857 -▁reviews -3858 -▁senior -3859 -▁hundred -3860 -IM -3861 -▁onto -3862 -▁whose -3863 -bed -3864 -▁Brown -3865 -net -3866 -▁fan -3867 -icing -3868 -▁Power -3869 -▁decor -3870 -▁secure -3871 -▁machine -3872 -imal -3873 -▁spread -3874 -▁u -3875 -▁frequ -3876 -▁score -3877 -ocolate -3878 -▁spirit -3879 -▁residents -3880 -amic -3881 -▁Hum -3882 -▁trade -3883 -▁science -3884 -vant -3885 -▁fra -3886 -▁Wood -3887 -▁appropri -3888 -▁officials -3889 -▁Sam -3890 -▁unit -3891 -▁died -3892 -hone -3893 -▁gone -3894 -▁manager -3895 -▁pressure -3896 -▁Like -3897 -▁challenges -3898 -TS -3899 -ady -3900 -▁clin -3901 -▁extend -3902 -▁instruct -3903 -▁dedicated -3904 -▁competition -3905 -▁Mount -3906 -▁Char -3907 -▁session -3908 -▁fant -3909 -▁Follow -3910 -▁happened -3911 -rian -3912 -▁Food -3913 -▁Mary -3914 -▁sort -3915 -ulated -3916 -▁initial -3917 -▁Fire -3918 -▁trou -3919 -▁Media -3920 -▁District -3921 -BA -3922 -icon -3923 -▁characters -3924 -▁basic -3925 -▁camera -3926 -▁holiday -3927 -azon -3928 -ategy -3929 -▁Enter -3930 -▁powerful -3931 -▁Institute -3932 -▁produce -3933 -▁beg -3934 -istics -3935 -▁Press -3936 -osition -3937 -▁dating -3938 -ette -3939 -asp -3940 -▁Hist -3941 -▁reasons -3942 -▁increasing -3943 -icken -3944 -▁shown -3945 -▁sugar -3946 -▁incred -3947 -▁extremely -3948 -▁rob -3949 -▁chem -3950 -▁Education -3951 -oos -3952 -▁AC -3953 -inese -3954 -▁volunte -3955 -▁disp -3956 -▁package -3957 -▁payment -3958 -RA -3959 -▁eval -3960 -▁guests -3961 -▁aren -3962 -▁snow -3963 -▁leader -3964 -▁biggest -3965 -▁TO -3966 -▁alone -3967 -▁object -3968 -▁proced -3969 -▁Sa -3970 -rowd -3971 -▁basis -3972 -▁disapp -3973 -▁supply -3974 -▁General -3975 -orney -3976 -▁Star -3977 -ifying -3978 -olic -3979 -▁laws -3980 -▁breat -3981 -▁graph -3982 -▁solid -3983 -▁forget -3984 -▁continues -3985 -LC -3986 -▁cars -3987 -▁guid -3988 -▁voice -3989 -▁experienced -3990 -▁Lou -3991 -▁mis -3992 -▁brows -3993 -rapy -3994 -▁arrest -3995 -▁passed -3996 -▁schedule -3997 -ken -3998 -omb -3999 -uing -4000 -▁egg -4001 -▁passion -4002 -▁dang -4003 -▁fear -4004 -▁guess -4005 -▁scene -4006 -esterday -4007 -BS -4008 -▁bur -4009 -▁steps -4010 -cel -4011 -▁Mal -4012 -▁beat -4013 -▁military -4014 -Sh -4015 -▁PR -4016 -▁Miss -4017 -gal -4018 -▁gra -4019 -▁names -4020 -▁approx -4021 -▁update -4022 -▁subst -4023 -▁During -4024 -▁protection -4025 -▁Att -4026 -▁Franc -4027 -▁French -4028 -annel -4029 -▁peace -4030 -▁conven -4031 -term -4032 -▁Who -4033 -▁ton -4034 -▁advantage -4035 -state -4036 -▁placed -4037 -▁Commission -4038 -▁pair -4039 -▁notice -4040 -▁strength -4041 -ero -4042 -What -4043 -incip -4044 -using -4045 -▁academ -4046 -▁Arch -4047 -▁epis -4048 -▁adding -4049 -▁waiting -4050 -▁although -4051 -ags -4052 -ideo -4053 -▁League -4054 -IV -4055 -▁Ben -4056 -clusive -4057 -▁Mot -4058 -▁reb -4059 -▁Alex -4060 -▁beauty -4061 -▁scient -4062 -ula -4063 -▁Dig -4064 -▁calls -4065 -▁relax -4066 -▁demonstr -4067 -▁regarding -4068 -amin -4069 -mark -4070 -ovel -4071 -▁income -4072 -▁covered -4073 -▁effects -4074 -ari -4075 -ixt -4076 -▁Sign -4077 -▁Online -4078 -uty -4079 -imin -4080 -▁copy -4081 -iverse -4082 -▁initi -4083 -▁experts -4084 -▁standards -4085 -▁technical -4086 -ros -4087 -okes -4088 -▁Atl -4089 -▁Vol -4090 -ading -4091 -▁manage -4092 -▁Chic -4093 -▁knows -4094 -▁winning -4095 -▁hospital -4096 -▁certainly -4097 -▁Real -4098 -▁batter -4099 -▁workers -4100 -▁connection -4101 -osh -4102 -▁compared -4103 -As -4104 -oe -4105 -▁RE -4106 -▁hom -4107 -ga -4108 -oop -4109 -▁Ins -4110 -▁Form -4111 -▁Development -4112 -▁wild -4113 -▁dinner -4114 -▁fabric -4115 -▁associated -4116 -▁experiences -4117 -▁Pay -4118 -▁doctor -4119 -▁master -4120 -▁cit -4121 -▁cru -4122 -▁wat -4123 -ograp -4124 -▁vote -4125 -▁posts -4126 -▁finding -4127 -▁Foundation -4128 -▁opened -4129 -▁Profess -4130 -▁reflect -4131 -IG -4132 -▁Carol -4133 -amm -4134 -▁audience -4135 -▁friendly -4136 -cell -4137 -unning -4138 -atically -4139 -mail -4140 -ctors -4141 -▁surface -4142 -▁den -4143 -▁Science -4144 -▁pm -4145 -▁Cap -4146 -itude -4147 -▁trail -4148 -▁artists -4149 -▁traffic -4150 -▁critical -4151 -▁communities -4152 -AA -4153 -uce -4154 -▁NY -4155 -▁Valley -4156 -works -4157 -▁remind -4158 -▁victim -4159 -▁Step -4160 -▁salt -4161 -▁followed -4162 -la -4163 -well -4164 -▁Rad -4165 -iques -4166 -▁Elect -4167 -▁football -4168 -tr -4169 -aming -4170 -▁electric -4171 -aven -4172 -▁Beach -4173 -▁facility -4174 -▁cry -4175 -gency -4176 -▁Disc -4177 -▁keeping -4178 -▁meaning -4179 -▁luck -4180 -▁pros -4181 -▁figure -4182 -▁learned -4183 -yer -4184 -ander -4185 -ulate -4186 -▁tickets -4187 -▁professionals -4188 -antic -4189 -▁laun -4190 -▁taste -4191 -▁instit -4192 -gen -4193 -▁bright -4194 -ech -4195 -arge -4196 -▁produced -4197 -▁watching -4198 -▁flex -4199 -▁catch -4200 -▁monitor -4201 -▁contains -4202 -lor -4203 -▁ter -4204 -There -4205 -ooper -4206 -▁entry -4207 -▁Project -4208 -▁Society -4209 -▁classic -4210 -▁department -4211 -edy -4212 -itar -4213 -▁diagn -4214 -▁lock -4215 -▁classes -4216 -rees -4217 -▁closed -4218 -▁starts -4219 -▁continued -4220 -▁dire -4221 -▁jump -4222 -▁awesome -4223 -▁kept -4224 -▁bought -4225 -▁listed -4226 -▁Christian -4227 -▁Wil -4228 -osure -4229 -▁Whether -4230 -▁neighbor -4231 -▁selected -4232 -▁Town -4233 -▁explore -4234 -▁testing -4235 -▁harm -4236 -▁Date -4237 -▁larger -4238 -▁videos -4239 -▁Another -4240 -▁presented -4241 -fast -4242 -▁Ber -4243 -▁ice -4244 -▁Times -4245 -▁transfer -4246 -▁thousands -4247 -▁developing -4248 -fin -4249 -▁capital -4250 -▁OF -4251 -iller -4252 -▁teaching -4253 -▁Mel -4254 -▁Nov -4255 -▁Long -4256 -▁force -4257 -▁grant -4258 -▁minute -4259 -▁talent -4260 -▁established -4261 -▁fol -4262 -▁Hill -4263 -▁desk -4264 -standing -4265 -▁England -4266 -▁AP -4267 -enses -4268 -▁announce -4269 -▁exciting -4270 -end -4271 -▁Vir -4272 -acity -4273 -▁Family -4274 -▁street -4275 -▁furn -4276 -▁facilities -4277 -▁Jim -4278 -▁brings -4279 -▁Tim -4280 -▁buying -4281 -▁records -4282 -▁articles -4283 -gn -4284 -▁sto -4285 -▁drug -4286 -▁ideal -4287 -▁library -4288 -▁requires -4289 -noon -4290 -itors -4291 -enance -4292 -▁Scott -4293 -▁micro -4294 -▁Chicago -4295 -win -4296 -rief -4297 -▁sup -4298 -▁rich -4299 -▁virt -4300 -▁novel -4301 -▁Chinese -4302 -▁sharing -4303 -▁updated -4304 -▁mo -4305 -part -4306 -sequ -4307 -▁Start -4308 -▁butter -4309 -▁driver -4310 -▁greater -4311 -riage -4312 -▁Sand -4313 -▁ship -4314 -▁crowd -4315 -▁wouldn -4316 -▁restaurant -4317 -imb -4318 -▁ir -4319 -lands -4320 -▁vision -4321 -▁Note -4322 -▁Exper -4323 -▁ingredients -4324 -ray -4325 -unately -4326 -▁List -4327 -▁poor -4328 -▁Stand -4329 -▁studies -4330 -▁Cup -4331 -overy -4332 -▁loan -4333 -▁Build -4334 -▁Grand -4335 -▁handle -4336 -▁plenty -4337 -▁resident -4338 -outs -4339 -▁bird -4340 -illage -4341 -ka -4342 -▁tree -4343 -▁economy -4344 -▁Central -4345 -▁leaving -4346 -▁serving -4347 -▁Div -4348 -▁sem -4349 -▁Support -4350 -SP -4351 -word -4352 -▁Mex -4353 -iture -4354 -▁beach -4355 -▁famous -4356 -ini -4357 -inn -4358 -▁Mil -4359 -lastname -4360 -▁manufacturer -4361 -▁faith -4362 -▁rooms -4363 -▁shall -4364 -▁recipe -4365 -▁Congress -4366 -CH -4367 -▁station -4368 -UR -4369 -▁react -4370 -▁shape -4371 -pective -4372 -▁origin -4373 -night -4374 -▁Amazon -4375 -▁injury -4376 -▁missing -4377 -reek -4378 -semb -4379 -▁Sil -4380 -▁upgr -4381 -▁Social -4382 -do -4383 -▁Pub -4384 -isher -4385 -▁motor -4386 -▁claims -4387 -▁medium -4388 -▁Bill -4389 -▁Posted -4390 -▁orders -4391 -▁maintain -4392 -rd -4393 -▁Fun -4394 -asure -4395 -▁brain -4396 -▁notes -4397 -▁views -4398 -▁Download -4399 -▁appropriate -4400 -▁boo -4401 -ishes -4402 -point -4403 -▁Offic -4404 -▁meant -4405 -▁older -4406 -▁spons -4407 -▁window -4408 -▁sustain -4409 -atab -4410 -▁Jesus -4411 -▁signed -4412 -berg -4413 -▁remove -4414 -cks -4415 -▁ended -4416 -▁changing -4417 -▁strategy -4418 -fr -4419 -cles -4420 -look -4421 -▁map -4422 -▁Union -4423 -outhern -4424 -▁happens -4425 -▁efficient -4426 -▁uns -4427 -going -4428 -▁advance -4429 -▁journal -4430 -ervation -4431 -▁plastic -4432 -▁Fore -4433 -▁stores -4434 -▁independent -4435 -▁iPhone -4436 -iest -4437 -▁useful -4438 -top -4439 -▁CD -4440 -umber -4441 -▁Organ -4442 -▁forms -4443 -▁leaves -4444 -▁Jul -4445 -craft -4446 -▁Light -4447 -▁Academ -4448 -acks -4449 -▁Award -4450 -▁advent -4451 -no -4452 -▁sand -4453 -▁shut -4454 -rehens -4455 -▁agency -4456 -▁repair -4457 -▁evidence -4458 -▁spending -4459 -▁afternoon -4460 -▁tim -4461 -apers -4462 -odes -4463 -rooms -4464 -▁throw -4465 -▁AND -4466 -▁menu -4467 -essions -4468 -▁secret -4469 -▁whatever -4470 -▁Fil -4471 -▁fee -4472 -estic -4473 -iliar -4474 -▁core -4475 -▁pray -4476 -▁sport -4477 -▁operations -4478 -▁combination -4479 -allery -4480 -▁Chris -4481 -▁Before -4482 -▁helpful -4483 -▁reality -4484 -atively -4485 -▁Where -4486 -▁multi -4487 -▁district -4488 -▁prepared -4489 -men -4490 -oyal -4491 -eless -4492 -icted -4493 -▁Week -4494 -▁cris -4495 -▁cab -4496 -ption -4497 -▁adop -4498 -▁tend -4499 -▁Democr -4500 -▁Series -4501 -▁status -4502 -▁balance -4503 -▁Mad -4504 -▁YOU -4505 -▁scen -4506 -▁estim -4507 -alls -4508 -▁flu -4509 -▁Both -4510 -▁flat -4511 -▁Author -4512 -▁joined -4513 -▁designs -4514 -▁remains -4515 -▁ID -4516 -▁Los -4517 -▁ride -4518 -▁corner -4519 -▁rank -4520 -▁eating -4521 -▁memory -4522 -Cl -4523 -mp -4524 -itz -4525 -▁Bet -4526 -▁Mont -4527 -▁caused -4528 -▁operating -4529 -▁Ma -4530 -aser -4531 -▁mist -4532 -▁George -4533 -▁discount -4534 -▁slightly -4535 -▁teachers -4536 -eed -4537 -▁IP -4538 -▁Women -4539 -▁esc -4540 -▁perhaps -4541 -▁primary -4542 -▁numerous -4543 -hem -4544 -▁funds -4545 -▁worry -4546 -▁survey -4547 -▁winner -4548 -▁enjoyed -4549 -▁showing -4550 -▁exercise -4551 -een -4552 -▁unc -4553 -▁Card -4554 -▁fourth -4555 -▁showed -4556 -▁spl -4557 -uries -4558 -▁anti -4559 -▁Francis -4560 -▁surgery -4561 -▁becoming -4562 -▁properties -4563 -pan -4564 -▁gain -4565 -▁recip -4566 -▁veget -4567 -▁Engine -4568 -▁markets -4569 -▁obvious -4570 -▁committed -4571 -▁suff -4572 -▁theme -4573 -▁focused -4574 -vere -4575 -▁plants -4576 -▁direction -4577 -ius -4578 -▁Tor -4579 -▁listen -4580 -▁managed -4581 -▁kick -4582 -iences -4583 -▁forum -4584 -▁chocolate -4585 -▁shel -4586 -▁limit -4587 -gers -4588 -lets -4589 -iency -4590 -▁legisl -4591 -aked -4592 -▁Its -4593 -▁Jun -4594 -▁busy -4595 -▁rain -4596 -issions -4597 -▁mechan -4598 -▁movement -4599 -▁encourage -4600 -▁rap -4601 -▁cloud -4602 -▁resist -4603 -▁putting -4604 -▁communication -4605 -OP -4606 -cher -4607 -▁bon -4608 -▁Their -4609 -▁raised -4610 -▁animals -4611 -▁assistance -4612 -?? -4613 -obe -4614 -oles -4615 -▁Bob -4616 -▁CEO -4617 -▁Full -4618 -▁Frank -4619 -▁lunch -4620 -▁defense -4621 -ita -4622 -▁analy -4623 -▁relig -4624 -life -4625 -rael -4626 -▁poll -4627 -▁corporate -4628 -▁practices -4629 -▁Technology -4630 -”, -4631 -itness -4632 -▁discover -4633 -▁Microsoft -4634 -", -4635 -gl -4636 -!!! -4637 -▁Mike -4638 -▁civil -4639 -▁reached -4640 -▁sources -4641 -bert -4642 -▁util -4643 -igation -4644 -vention -4645 -▁society -4646 -▁yesterday -4647 -orter -4648 -▁mill -4649 -▁chair -4650 -▁Wr -4651 -▁scr -4652 -▁youth -4653 -▁central -4654 -abilities -4655 -▁advanced -4656 -▁Ham -4657 -▁cart -4658 -▁architect -4659 -▁determine -4660 -REE -4661 -▁Fort -4662 -arrant -4663 -▁cleaning -4664 -▁vehicles -4665 -▁firstname -4666 -ena -4667 -ror -4668 -west -4669 -▁Tri -4670 -▁tea -4671 -▁dete -4672 -▁rare -4673 -▁AS -4674 -▁NOT -4675 -▁Mass -4676 -▁actual -4677 -yan -4678 -▁psych -4679 -▁Robert -4680 -▁tables -4681 -▁worksh -4682 -▁methods -4683 -▁leadership -4684 -▁Bur -4685 -▁ath -4686 -▁structure -4687 -kin -4688 -▁vs -4689 -▁pock -4690 -aturing -4691 -▁Commit -4692 -CC -4693 -MS -4694 -iled -4695 -▁Log -4696 -▁Set -4697 -▁fell -4698 -▁register -4699 -?” -4700 -▁repe -4701 -▁battle -4702 -▁format -4703 -▁becomes -4704 -▁willing -4705 -bre -4706 -ifts -4707 -▁colle -4708 -▁charges -4709 -▁funding -4710 -▁updates -4711 -▁thoughts -4712 -▁ju -4713 -▁Tre -4714 -ordin -4715 -▁toward -4716 -▁appears -4717 -▁visitors -4718 -▁fees -4719 -▁incor -4720 -▁sector -4721 -▁Copyright -4722 -▁absolutely -4723 -▁temperature -4724 -▁lose -4725 -▁locations -4726 -▁Keep -4727 -▁Next -4728 -▁colour -4729 -▁filled -4730 -▁songs -4731 -▁Network -4732 -▁Old -4733 -▁instru -4734 -levision -4735 -▁Wall -4736 -▁Trump -4737 -▁brown -4738 -▁Spring -4739 -▁century -4740 -▁extensive -4741 -▁Conference -4742 -kins -4743 -▁Land -4744 -▁Learn -4745 -▁Louis -4746 -▁asking -4747 -▁environmental -4748 -ola -4749 -ship -4750 -▁Way -4751 -▁topic -4752 -▁favour -4753 -▁transl -4754 -▁courses -4755 -▁profile -4756 -▁AL -4757 -▁Ol -4758 -while -4759 -▁Test -4760 -▁south -4761 -▁dur -4762 -▁Medic -4763 -▁Report -4764 -▁documents -4765 -▁previously -4766 -coh -4767 -▁Dou -4768 -▁Oper -4769 -▁adapt -4770 -▁north -4771 -ception -4772 -ipl -4773 -▁Plus -4774 -▁bowl -4775 -▁swim -4776 -ivered -4777 -▁guest -4778 -▁refer -4779 -▁visual -4780 -▁readers -4781 -▁anywhere -4782 -▁kid -4783 -▁registered -4784 -otton -4785 -▁Jeff -4786 -▁France -4787 -For -4788 -▁Cre -4789 -▁Lim -4790 -▁lux -4791 -▁sch -4792 -▁polic -4793 -▁charged -4794 -▁expertise -4795 -New -4796 -water -4797 -▁task -4798 -iration -4799 -▁upcoming -4800 -▁UN -4801 -▁wire -4802 -▁allowing -4803 -FL -4804 -▁Ok -4805 -▁selling -4806 -po -4807 -bour -4808 -▁bask -4809 -▁recommended -4810 -▁stre -4811 -▁Hotel -4812 -▁plays -4813 -▁Android -4814 -▁coverage -4815 -icip -4816 -▁Lat -4817 -▁fuel -4818 -▁neck -4819 -▁audio -4820 -▁sounds -4821 -▁Library -4822 -▁population -4823 -list -4824 -umin -4825 -▁Only -4826 -▁Conne -4827 -▁featured -4828 -▁Saf -4829 -▁pal -4830 -▁joint -4831 -▁Medical -4832 -▁princip -4833 -▁smaller -4834 -▁walking -4835 -▁ur -4836 -ulty -4837 -▁thr -4838 -▁Prov -4839 -▁seat -4840 -▁mental -4841 -▁establish -4842 -▁discussion -4843 -▁Jew -4844 -▁tun -4845 -▁apart -4846 -▁trial -4847 -▁parties -4848 -▁NE -4849 -istan -4850 -▁dance -4851 -ferences -4852 -IA -4853 -azz -4854 -ora -4855 -osis -4856 -▁Somet -4857 -▁Watch -4858 -igan -4859 -prise -4860 -▁Main -4861 -▁dogs -4862 -▁radio -4863 -▁despite -4864 -On -4865 -▁Lord -4866 -▁Walk -4867 -▁fold -4868 -▁truck -4869 -▁Africa -4870 -▁Virgin -4871 -▁scheduled -4872 -▁maintenance -4873 -▁Head -4874 -▁inspired -4875 -▁ON -4876 -▁diet -4877 -▁nine -4878 -▁restr -4879 -SA -4880 -▁writer -4881 -▁outdoor -4882 -▁Security -4883 -▁accommod -4884 -▁combined -4885 -▁van -4886 -ki -4887 -▁CA -4888 -▁har -4889 -▁citiz -4890 -▁scored -4891 -aks -4892 -alog -4893 -▁Western -4894 -rehensive -4895 -▁techniques -4896 -OO -4897 -▁Game -4898 -▁Admin -4899 -▁decide -4900 -▁seconds -4901 -▁Soft -4902 -▁Museum -4903 -▁values -4904 -▁removed -4905 -▁provider -4906 -▁sav -4907 -▁earth -4908 -▁raise -4909 -▁accompl -4910 -ownt -4911 -▁metal -4912 -▁stret -4913 -▁researc -4914 -eal -4915 -▁Place -4916 -▁spect -4917 -▁elements -4918 -▁purchased -4919 -▁joy -4920 -▁calc -4921 -▁purs -4922 -▁trees -4923 -▁launched -4924 -zen -4925 -▁Hy -4926 -▁Mer -4927 -▁sea -4928 -▁honest -4929 -▁movies -4930 -▁innovative -4931 -An -4932 -IF -4933 -▁panel -4934 -idering -4935 -▁counter -4936 -▁shooting -4937 -▁delicious -4938 -▁approximately -4939 -▁sitting -4940 -gment -4941 -▁killed -4942 -▁separate -4943 -▁edge -4944 -▁Video -4945 -▁Digital -4946 -▁teacher -4947 -▁relevant -4948 -ano -4949 -▁matt -4950 -▁approved -4951 -gage -4952 -▁lovely -4953 -▁parking -4954 -▁consumers -4955 -▁executive -4956 -My -4957 -nel -4958 -van -4959 -▁steel -4960 -▁Israel -4961 -▁Angeles -4962 -▁Manager -4963 -▁magazine -4964 -rs -4965 -ye -4966 -orry -4967 -▁hearing -4968 -▁concerns -4969 -bu -4970 -appy -4971 -igned -4972 -ushed -4973 -▁Charl -4974 -▁Person -4975 -pet -4976 -ellig -4977 -known -4978 -▁chat -4979 -▁conv -4980 -▁Georg -4981 -▁Peter -4982 -ensions -4983 -▁mostly -4984 -▁agreement -4985 -ears -4986 -▁eth -4987 -▁milk -4988 -▁rise -4989 -▁occasion -4990 -ups -4991 -▁Aud -4992 -▁tow -4993 -olars -4994 -▁Cook -4995 -▁Data -4996 -▁Join -4997 -isation -4998 -▁cheese -4999 -▁highlight -5000 -▁generation -5001 -VD -5002 -▁Ext -5003 -▁Ill -5004 -▁Penn -5005 -▁Word -5006 -▁Const -5007 -osit -5008 -▁mur -5009 -▁rid -5010 -▁Room -5011 -▁Thomas -5012 -▁identify -5013 -▁Gal -5014 -▁Pac -5015 -▁Centre -5016 -▁connected -5017 -▁intended -5018 -▁appearance -5019 -TV -5020 -fol -5021 -ring -5022 -orthern -5023 -▁controll -5024 -PA -5025 -ris -5026 -apes -5027 -▁sets -5028 -▁Prote -5029 -▁feels -5030 -▁waste -5031 -▁described -5032 -▁operation -5033 -▁commitment -5034 -▁Mo -5035 -▁Ver -5036 -irmed -5037 -▁truth -5038 -▁Master -5039 -▁academic -5040 -▁delivered -5041 -▁participate -5042 -cm -5043 -▁sympt -5044 -▁Through -5045 -ournament -5046 -!) -5047 -ENT -5048 -▁Men -5049 -oston -5050 -▁Lead -5051 -▁push -5052 -▁stars -5053 -▁Indust -5054 -▁Invest -5055 -▁server -5056 -▁Children -5057 -▁familiar -5058 -▁marriage -5059 -osen -5060 -▁Bas -5061 -▁nom -5062 -▁Arts -5063 -▁tough -5064 -▁enhance -5065 -▁capacity -5066 -▁relationships -5067 -UT -5068 -ycl -5069 -▁Upd -5070 -reens -5071 -▁cooking -5072 -▁promote -5073 -den -5074 -elines -5075 -▁landsc -5076 -ker -5077 -alend -5078 -nergy -5079 -▁cells -5080 -▁campus -5081 -▁editor -5082 -mond -5083 -▁mort -5084 -▁optim -5085 -▁cities -5086 -▁Journal -5087 -▁decisions -5088 -▁generally -5089 -▁Fair -5090 -▁signs -5091 -▁Access -5092 -▁wearing -5093 -▁therefore -5094 -▁introduced -5095 -arsh -5096 -berry -5097 -▁Vict -5098 -▁breast -5099 -▁accident -5100 -▁properly -5101 -▁processes -5102 -▁Er -5103 -prene -5104 -▁educational -5105 -▁Ul -5106 -▁Cam -5107 -cohol -5108 -eline -5109 -▁situ -5110 -▁majority -5111 -▁investigation -5112 -anda -5113 -inch -5114 -▁jew -5115 -▁minor -5116 -ya -5117 -burg -5118 -▁arm -5119 -ishing -5120 -▁opinion -5121 -▁detailed -5122 -▁Government -5123 -▁Dev -5124 -▁fly -5125 -▁Hand -5126 -▁Rest -5127 -reprene -5128 -▁technologies -5129 -▁teen -5130 -▁Chief -5131 -▁Earth -5132 -atabase -5133 -▁Global -5134 -▁minimum -5135 -▁category -5136 -▁presence -5137 -IR -5138 -▁Lab -5139 -▁ban -5140 -▁Live -5141 -▁label -5142 -▁calling -5143 -▁returned -5144 -▁emergency -5145 -▁expensive -5146 -▁mentioned -5147 -ef -5148 -▁Tur -5149 -▁feedback -5150 -fortunately -5151 -▁responsibility -5152 -▁Ari -5153 -▁Fund -5154 -▁Ohio -5155 -▁Wild -5156 -ression -5157 -▁Committee -5158 -▁installed -5159 -DF -5160 -▁Mur -5161 -▁ring -5162 -▁square -5163 -▁Johnson -5164 -▁foreign -5165 -▁bringing -5166 -▁hundreds -5167 -▁websites -5168 -▁Americans -5169 -▁installation -5170 -col -5171 -▁Que -5172 -▁plug -5173 -▁female -5174 -▁ourselves -5175 -rag -5176 -razy -5177 -▁Boston -5178 -▁entertainment -5179 -otten -5180 -ternal -5181 -▁invent -5182 -▁arrange -5183 -▁behavior -5184 -▁exchange -5185 -▁performed -5186 -▁episode -5187 -▁factors -5188 -▁consumer -5189 -▁advertising -5190 -ien -5191 -▁Pack -5192 -▁sizes -5193 -▁begins -5194 -▁satisf -5195 -hab -5196 -text -5197 -▁appeared -5198 -▁Di -5199 -▁Kn -5200 -aded -5201 -▁brief -5202 -▁sides -5203 -▁veter -5204 -▁Squ -5205 -▁flo -5206 -▁teach -5207 -▁units -5208 -▁studio -5209 -uts -5210 -▁Den -5211 -▁coast -5212 -ictions -5213 -emporary -5214 -▁MP -5215 -rist -5216 -▁Adv -5217 -▁Sup -5218 -▁Human -5219 -▁Federal -5220 -AY -5221 -▁elig -5222 -▁icon -5223 -▁tight -5224 -▁caught -5225 -▁transform -5226 -▁confidence -5227 -icians -5228 -▁chief -5229 -▁sauce -5230 -▁thick -5231 -ae -5232 -When -5233 -iser -5234 -▁Tour -5235 -▁fruit -5236 -▁Colorado -5237 -▁honor -5238 -▁holding -5239 -▁reserved -5240 -lock -5241 -▁Wal -5242 -▁Those -5243 -▁adults -5244 -▁topics -5245 -▁policies -5246 -▁supporting -5247 -spe -5248 -uke -5249 -▁https -5250 -▁Contin -5251 -▁ven -5252 -OC -5253 -hew -5254 -cean -5255 -▁alle -5256 -▁meat -5257 -▁ment -5258 -▁achie -5259 -▁chicken -5260 -▁windows -5261 -▁confident -5262 -▁HD -5263 -acle -5264 -▁vary -5265 -▁Price -5266 -rastructure -5267 -▁administration -5268 -▁Pan -5269 -▁motiv -5270 -▁animal -5271 -ifications -5272 -▁supported -5273 -with -5274 -▁Jud -5275 -▁cro -5276 -▁fantastic -5277 -ushing -5278 -▁mouth -5279 -▁sexual -5280 -▁seeking -5281 -SS -5282 -▁meal -5283 -▁Creat -5284 -▁alternative -5285 -arp -5286 -iat -5287 -arks -5288 -oted -5289 -▁Maybe -5290 -▁victory -5291 -ait -5292 -how -5293 -▁Bi -5294 -▁Search -5295 -▁Carolina -5296 -▁Australian -5297 -kes -5298 -ancer -5299 -▁Germany -5300 -▁components -5301 -▁importance -5302 -▁competitive -5303 -vy -5304 -▁sy -5305 -▁Prem -5306 -▁quiet -5307 -▁basket -5308 -▁edition -5309 -paper -5310 -▁tele -5311 -▁sister -5312 -▁dollars -5313 -rier -5314 -▁cheap -5315 -▁leads -5316 -▁thread -5317 -▁apparent -5318 -ste -5319 -▁Jon -5320 -▁rom -5321 -▁rub -5322 -unting -5323 -▁Canad -5324 -▁Sports -5325 -▁switch -5326 -▁guarantee -5327 -▁Academy -5328 -▁conduct -5329 -▁confirm -5330 -▁transact -5331 -▁conversation -5332 -inct -5333 -▁Lin -5334 -ighter -5335 -▁distance -5336 -▁Tit -5337 -▁Young -5338 -▁recru -5339 -▁centre -5340 -▁measures -5341 -▁worldwide -5342 -Com -5343 -▁Gar -5344 -▁Gen -5345 -▁info -5346 -▁Festival -5347 -▁Students -5348 -.| -5349 -etic -5350 -▁Bal -5351 -▁fif -5352 -▁picked -5353 -iability -5354 -▁remaining -5355 -▁photograph -5356 -weet -5357 -▁Jose -5358 -weight -5359 -▁bread -5360 -▁license -5361 -away -5362 -ucks -5363 -▁impl -5364 -▁flight -5365 -▁totally -5366 -▁Nor -5367 -▁rat -5368 -▁Meet -5369 -▁doubt -5370 -▁prison -5371 -▁unless -5372 -▁tack -5373 -▁Martin -5374 -inations -5375 -NA -5376 -atre -5377 -▁Sar -5378 -▁ang -5379 -▁vir -5380 -achel -5381 -uable -5382 -▁species -5383 -How -5384 -elly -5385 -ersey -5386 -▁restaurants -5387 -▁comprehensive -5388 -asks -5389 -▁seek -5390 -▁doors -5391 -▁contest -5392 -▁agencies -5393 -ailability -5394 -▁Champions -5395 -iano -5396 -verse -5397 -▁Quest -5398 -▁tests -5399 -▁faster -5400 -▁delight -5401 -▁maximum -5402 -▁celebrate -5403 -uzz -5404 -eries -5405 -▁league -5406 -▁clearly -5407 -▁musical -5408 -▁visiting -5409 -▁photograp -5410 -RC -5411 -TH -5412 -Our -5413 -▁Type -5414 -▁forg -5415 -itable -5416 -▁depart -5417 -▁painting -5418 -▁eventually -5419 -pass -5420 -▁Did -5421 -▁dyn -5422 -▁wel -5423 -estyle -5424 -▁noted -5425 -▁planned -5426 -▁election -5427 -▁revealed -5428 -▁considering -5429 -TC -5430 -otic -5431 -▁Inte -5432 -▁propos -5433 -▁prepare -5434 -▁depending -5435 -▁Cred -5436 -▁Using -5437 -▁Energy -5438 -▁arrived -5439 -▁housing -5440 -▁married -5441 -▁university -5442 -igr -5443 -▁Ro -5444 -usion -5445 -▁burn -5446 -▁lived -5447 -▁ticket -5448 -▁Hospital -5449 -▁bike -5450 -▁mine -5451 -▁Jackson -5452 -▁sessions -5453 -erg -5454 -▁Ce -5455 -▁inn -5456 -iminal -5457 -ixture -5458 -orough -5459 -▁scale -5460 -▁Assist -5461 -▁SP -5462 -wing -5463 -▁McC -5464 -▁ign -5465 -▁ris -5466 -ulous -5467 -▁FREE -5468 -▁apps -5469 -▁otherwise -5470 -▁discovered -5471 -▁Mid -5472 -▁Cost -5473 -▁compar -5474 -▁gather -5475 -▁officer -5476 -mes -5477 -▁Secret -5478 -▁climate -5479 -▁monthly -5480 -▁Japanese -5481 -▁chemical -5482 -▁neighborhood -5483 -▁boys -5484 -▁ends -5485 -▁liqu -5486 -▁evalu -5487 -▁turns -5488 -▁inches -5489 -▁spokes -5490 -▁struct -5491 -▁commission -5492 -▁Kore -5493 -▁weap -5494 -▁symptoms -5495 -ht -5496 -▁Bul -5497 -▁Cat -5498 -agram -5499 -▁freed -5500 -▁missed -5501 -▁cutting -5502 -▁accounts -5503 -▁internal -5504 -▁reliable -5505 -ias -5506 -▁ran -5507 -tered -5508 -▁pump -5509 -▁surf -5510 -related -5511 -▁brands -5512 -▁lights -5513 -▁seemed -5514 -▁appreciate -5515 -▁participants -5516 -otes -5517 -alian -5518 -▁Know -5519 -▁battery -5520 -▁organic -5521 -▁affordable -5522 -edia -5523 -▁hyd -5524 -▁Cert -5525 -▁corn -5526 -▁twice -5527 -▁Applic -5528 -▁Columb -5529 -▁Georgia -5530 -▁cultural -5531 -▁resource -5532 -▁featuring -5533 -hi -5534 -▁Second -5535 -▁automatically -5536 -They -5537 -ician -5538 -▁valid -5539 -▁athlet -5540 -▁paying -5541 -▁submit -5542 -▁African -5543 -▁meetings -5544 -iors -5545 -▁Code -5546 -▁Jones -5547 -▁Andrew -5548 -EE -5549 -▁emp -5550 -▁Share -5551 -▁bigger -5552 -▁regularly -5553 -); -5554 -Ex -5555 -but -5556 -▁Hard -5557 -▁Qual -5558 -▁debt -5559 -▁Middle -5560 -▁failed -5561 -▁supposed -5562 -▁Ep -5563 -▁Help -5564 -▁Steve -5565 -▁storm -5566 -▁accurate -5567 -▁possibly -5568 -GB -5569 -ua -5570 -ban -5571 -▁mel -5572 -▁pod -5573 -▁boost -5574 -▁deals -5575 -▁labor -5576 -▁volume -5577 -▁television -5578 -▁presentation -5579 -cont -5580 -▁fro -5581 -▁draft -5582 -▁fellow -5583 -▁realize -5584 -▁manufacturing -5585 -Pro -5586 -▁Ut -5587 -▁fle -5588 -▁Daniel -5589 -▁concent -5590 -▁Virginia -5591 -▁messages -5592 -?" -5593 -▁SH -5594 -ennis -5595 -idden -5596 -pected -5597 -▁fields -5598 -▁revenue -5599 -▁affected -5600 -▁recovery -5601 -EST -5602 -rupt -5603 -▁Boy -5604 -▁Blog -5605 -▁German -5606 -▁covers -5607 -▁shares -5608 -▁proposed -5609 -▁researchers -5610 -No -5611 -roy -5612 -eper -5613 -mosp -5614 -▁die -5615 -rical -5616 -▁Page -5617 -iamond -5618 -alendar -5619 -oration -5620 -▁Rights -5621 -ployment -5622 -▁returns -5623 -▁engineering -5624 -▁Lee -5625 -▁Tem -5626 -▁Farm -5627 -▁Travel -5628 -▁birthday -5629 -▁AD -5630 -case -5631 -▁Rom -5632 -▁aid -5633 -▁ages -5634 -▁Little -5635 -▁confirmed -5636 -▁instructions -5637 -▁amb -5638 -cious -5639 -▁Cast -5640 -▁Trust -5641 -▁dates -5642 -▁tells -5643 -▁answers -5644 -▁creation -5645 -▁interior -5646 -▁protected -5647 -ca -5648 -ters -5649 -▁Tech -5650 -▁breakfast -5651 -▁sad -5652 -▁wal -5653 -▁dish -5654 -▁chart -5655 -▁warrant -5656 -▁industrial -5657 -▁infrastructure -5658 -iner -5659 -▁nor -5660 -which -5661 -▁Orig -5662 -▁Games -5663 -▁Visit -5664 -▁loves -5665 -▁Mexico -5666 -▁county -5667 -▁applied -5668 -▁browser -5669 -▁employee -5670 -ario -5671 -▁nurs -5672 -▁agent -5673 -▁pregn -5674 -▁specifically -5675 -▁Opt -5676 -▁mir -5677 -▁poly -5678 -▁route -5679 -▁desire -5680 -▁issued -5681 -▁choices -5682 -▁decades -5683 -▁drivers -5684 -▁NC -5685 -▁Hen -5686 -▁hook -5687 -▁rapid -5688 -▁furniture -5689 -▁chain -5690 -▁foods -5691 -fection -5692 -▁flowers -5693 -▁reference -5694 -▁twe -5695 -▁hero -5696 -▁jack -5697 -▁affili -5698 -▁element -5699 -▁perfectly -5700 -▁WH -5701 -gend -5702 -▁Joe -5703 -erves -5704 -▁thus -5705 -lights -5706 -▁attorney -5707 -▁standing -5708 -▁exclusive -5709 -ansas -5710 -▁tail -5711 -▁plate -5712 -▁chosen -5713 -▁earned -5714 -▁supports -5715 -upp -5716 -▁CH -5717 -▁anc -5718 -▁yes -5719 -anger -5720 -odies -5721 -▁Made -5722 -▁bond -5723 -▁Broad -5724 -▁talks -5725 -▁Control -5726 -▁Francisco -5727 -▁employment -5728 -hand -5729 -rick -5730 -▁Ken -5731 -hetic -5732 -oking -5733 -▁mode -5734 -▁vent -5735 -▁Brand -5736 -▁remote -5737 -ibilities -5738 -▁Executive -5739 -anna -5740 -irms -5741 -▁Dom -5742 -▁End -5743 -ospit -5744 -▁Enjoy -5745 -▁agreed -5746 -▁purposes -5747 -▁apartment -5748 -▁incredible -5749 -Al -5750 -▁AT -5751 -▁Lo -5752 -lymp -5753 -▁Bon -5754 -▁wid -5755 -▁Expl -5756 -▁broken -5757 -▁improved -5758 -▁strategies -5759 -UN -5760 -can -5761 -▁DVD -5762 -▁nav -5763 -▁Does -5764 -▁logo -5765 -▁Store -5766 -▁Williams -5767 -▁processing -5768 -▁Hope -5769 -▁Pass -5770 -▁Sher -5771 -▁Current -5772 -▁illustr -5773 -▁hardware -5774 -▁surrounding -5775 -▁Sy -5776 -anges -5777 -▁cake -5778 -▁cute -5779 -▁whom -5780 -▁advis -5781 -▁Product -5782 -▁recorded -5783 -▁disappoint -5784 -BI -5785 -MA -5786 -▁Id -5787 -ench -5788 -hent -5789 -▁Equ -5790 -▁Haw -5791 -▁lit -5792 -▁Coast -5793 -▁quant -5794 -▁reput -5795 -▁rough -5796 -▁premium -5797 -aped -5798 -▁Mic -5799 -adium -5800 -▁golf -5801 -ampion -5802 -▁holds -5803 -▁judge -5804 -▁pleased -5805 -▁accepted -5806 -▁suitable -5807 -umes -5808 -idays -5809 -▁boat -5810 -▁Point -5811 -▁downt -5812 -▁losing -5813 -▁Instead -5814 -▁male -5815 -▁pure -5816 -▁grade -5817 -▁trouble -5818 -uous -5819 -▁rule -5820 -▁Three -5821 -▁wheel -5822 -▁administr -5823 -▁buildings -5824 -lyn -5825 -oga -5826 -uits -5827 -▁usual -5828 -▁History -5829 -▁explain -5830 -▁domestic -5831 -▁concerned -5832 -!” -5833 -xy -5834 -itage -5835 -▁telling -5836 -▁Minister -5837 -▁violence -5838 -▁candidates -5839 -gas -5840 -ums -5841 -▁moist -5842 -▁licens -5843 -▁aspects -5844 -▁Communic -5845 -▁injuries -5846 -▁favourite -5847 -tra -5848 -▁ok -5849 -what -5850 -▁Girl -5851 -person -5852 -▁moments -5853 -▁typically -5854 -otal -5855 -▁pun -5856 -▁tur -5857 -▁Party -5858 -▁error -5859 -▁causes -5860 -▁styles -5861 -▁Italian -5862 -▁awareness -5863 -▁registration -5864 -▁vit -5865 -▁arts -5866 -▁phil -5867 -▁Night -5868 -▁Print -5869 -▁Perform -5870 -rim -5871 -road -5872 -lines -5873 -▁oven -5874 -▁grown -5875 -▁enable -5876 -▁island -5877 -▁greatest -5878 -vell -5879 -▁Harr -5880 -▁rand -5881 -orable -5882 -▁abuse -5883 -▁shoes -5884 -▁forces -5885 -▁stated -5886 -fficient -5887 -▁surprise -5888 -va -5889 -▁FOR -5890 -▁Key -5891 -▁tag -5892 -▁taxes -5893 -▁photography -5894 -ERS -5895 -hors -5896 -▁jun -5897 -anish -5898 -cluding -5899 -▁closer -5900 -▁citizens -5901 -▁negative -5902 -▁influence -5903 -CA -5904 -bur -5905 -writ -5906 -▁Four -5907 -▁circum -5908 -▁actions -5909 -ria -5910 -▁Def -5911 -▁Dog -5912 -tters -5913 -ulture -5914 -▁retire -5915 -▁script -5916 -▁stopped -5917 -▁stretch -5918 -▁broadcast -5919 -▁Wi -5920 -pond -5921 -▁Drive -5922 -▁Local -5923 -▁gradu -5924 -▁resol -5925 -▁Division -5926 -▁wet -5927 -▁crew -5928 -▁powder -5929 -▁database -5930 -▁tomorrow -5931 -▁sam -5932 -astern -5933 -▁Olymp -5934 -▁leather -5935 -▁practical -5936 -ribe -5937 -▁Bra -5938 -▁Ell -5939 -▁Max -5940 -▁adm -5941 -▁argu -5942 -Un -5943 -▁serves -5944 -▁weekly -5945 -▁alleged -5946 -iami -5947 -udden -5948 -▁shock -5949 -▁Pacific -5950 -▁payments -5951 -▁functions -5952 -▁inspiration -5953 -DS -5954 -▁Gra -5955 -stone -5956 -▁acid -5957 -▁bound -5958 -▁faculty -5959 -And -5960 -yers -5961 -▁tro -5962 -alled -5963 -▁mini -5964 -▁funny -5965 -▁Awards -5966 -▁speech -5967 -▁receiving -5968 -▁authorities -5969 -ava -5970 -hus -5971 -▁Mat -5972 -merce -5973 -▁Ryan -5974 -▁sequ -5975 -▁thin -5976 -lywood -5977 -▁column -5978 -▁designer -5979 -ucle -5980 -▁hits -5981 -▁cable -5982 -forcement -5983 -▁supplies -5984 -▁Available -5985 -▁electronic -5986 -TA -5987 -ERE -5988 -▁rot -5989 -atholic -5990 -▁config -5991 -▁pepper -5992 -▁village -5993 -▁identified -5994 -▁tut -5995 -▁gear -5996 -▁Cross -5997 -▁random -5998 -poration -5999 -▁everyday -6000 -▁committee -6001 -GE -6002 -bol -6003 -oup -6004 -irty -6005 -▁Hor -6006 -▁Oil -6007 -under -6008 -profit -6009 -▁Econom -6010 -▁perman -6011 -▁recognized -6012 -ache -6013 -▁Aff -6014 -itate -6015 -never -6016 -right -6017 -▁Coll -6018 -▁Need -6019 -▁grab -6020 -▁atmosp -6021 -▁degrees -6022 -▁printed -6023 -▁convenient -6024 -▁healthcare -6025 -▁impressive -6026 -PM -6027 -mar -6028 -inet -6029 -▁crime -6030 -▁keeps -6031 -▁lessons -6032 -▁Michigan -6033 -Pl -6034 -So -6035 -rip -6036 -▁tab -6037 -▁Bell -6038 -▁Cond -6039 -isters -6040 -▁essay -6041 -▁flour -6042 -▁crisis -6043 -▁height -6044 -▁emotional -6045 -▁determined -6046 -▁Cas -6047 -▁Ref -6048 -▁Tay -6049 -▁voc -6050 -atoes -6051 -etime -6052 -▁Ariz -6053 -▁films -6054 -▁imagine -6055 -▁treated -6056 -▁Sometimes -6057 -▁dangerous -6058 -▁happening -6059 -▁Lt -6060 -▁PS -6061 -aren -6062 -phas -6063 -▁Dun -6064 -▁Try -6065 -▁Small -6066 -▁crazy -6067 -▁Comple -6068 -▁ongoing -6069 -▁champions -6070 -▁explained -6071 -iate -6072 -hered -6073 -inter -6074 -▁Jenn -6075 -▁Mean -6076 -uction -6077 -▁Santa -6078 -▁fixed -6079 -▁sheet -6080 -▁entreprene -6081 -Ar -6082 -▁Run -6083 -▁Sus -6084 -urban -6085 -▁Safety -6086 -▁dropped -6087 -▁Marketing -6088 -cue -6089 -rum -6090 -▁Fed -6091 -▁patterns -6092 -▁resolution -6093 -▁du -6094 -pret -6095 -▁Mach -6096 -▁Canadian -6097 -▁investors -6098 -LS -6099 -All -6100 -aid -6101 -eler -6102 -made -6103 -▁row -6104 -▁worse -6105 -▁Victor -6106 -▁dining -6107 -iversary -6108 -▁subscrib -6109 -▁gro -6110 -anged -6111 -arian -6112 -▁Writ -6113 -▁rear -6114 -▁Guide -6115 -▁command -6116 -▁trading -6117 -▁conducted -6118 -▁tradition -6119 -LA -6120 -mary -6121 -anche -6122 -osoph -6123 -▁Rose -6124 -▁soul -6125 -▁taught -6126 -▁arrested -6127 -▁attended -6128 -▁officers -6129 -▁appointment -6130 -▁collaboration -6131 -Bl -6132 -Con -6133 -▁GM -6134 -▁Kh -6135 -enced -6136 -▁lift -6137 -▁simpl -6138 -▁extended -6139 -lete -6140 -▁der -6141 -▁Priv -6142 -▁cock -6143 -▁grad -6144 -▁roof -6145 -▁Chair -6146 -▁hoping -6147 -▁alcohol -6148 -▁positions -6149 -▁Environment -6150 -▁successfully -6151 -ppers -6152 -oosing -6153 -▁native -6154 -▁tournament -6155 -Don -6156 -inson -6157 -▁grew -6158 -▁wash -6159 -▁depth -6160 -▁flood -6161 -▁Account -6162 -▁freedom -6163 -▁ordered -6164 -▁eligible -6165 -▁incident -6166 -▁sick -6167 -▁folks -6168 -▁Senate -6169 -▁versions -6170 -iana -6171 -▁Inf -6172 -▁kne -6173 -▁Mult -6174 -▁spin -6175 -▁Richard -6176 -ello -6177 -rate -6178 -▁obtain -6179 -▁severe -6180 -▁Sat -6181 -aints -6182 -▁Turn -6183 -▁Photo -6184 -▁cycle -6185 -▁guard -6186 -▁teeth -6187 -▁noticed -6188 -iki -6189 -▁bat -6190 -▁Area -6191 -▁Paris -6192 -▁advoc -6193 -▁belong -6194 -▁forced -6195 -▁massive -6196 -▁graduate -6197 -▁construct -6198 -Be -6199 -ala -6200 -cers -6201 -essed -6202 -racts -6203 -▁adds -6204 -▁dram -6205 -▁none -6206 -▁houses -6207 -▁improvement -6208 -hire -6209 -real -6210 -rics -6211 -▁Daily -6212 -▁trend -6213 -iveness -6214 -▁Summer -6215 -▁tested -6216 -▁failure -6217 -▁Building -6218 -▁valuable -6219 -▁innovation -6220 -tle -6221 -▁ol -6222 -▁Kent -6223 -▁Which -6224 -▁mixed -6225 -▁shots -6226 -▁yards -6227 -▁cotton -6228 -▁regional -6229 -ayer -6230 -utch -6231 -▁Ash -6232 -▁Die -6233 -rease -6234 -▁Carl -6235 -▁Clean -6236 -▁Right -6237 -▁council -6238 -Is -6239 -▁MS -6240 -▁Box -6241 -▁Rev -6242 -▁thorough -6243 -▁integrated -6244 -▁DC -6245 -▁syn -6246 -▁Size -6247 -▁tiny -6248 -hentic -6249 -▁output -6250 -za -6251 -▁ec -6252 -inem -6253 -▁tank -6254 -▁owned -6255 -▁concert -6256 -▁knowing -6257 -▁routine -6258 -▁turning -6259 -▁efficiency -6260 -erse -6261 -▁drugs -6262 -▁Avenue -6263 -▁facing -6264 -▁guitar -6265 -▁diverse -6266 -▁therapy -6267 -▁clothing -6268 -▁providers -6269 -▁MO -6270 -▁Sn -6271 -▁Ent -6272 -▁Tool -6273 -acking -6274 -▁Select -6275 -▁publish -6276 -▁reduced -6277 -▁interface -6278 -CE -6279 -▁fo -6280 -▁Hon -6281 -osite -6282 -secut -6283 -▁Asia -6284 -▁Though -6285 -▁yellow -6286 -▁follows -6287 -▁description -6288 -▁distribution -6289 -illy -6290 -▁LLC -6291 -▁ped -6292 -abled -6293 -ansion -6294 -▁Training -6295 -▁settings -6296 -▁surprised -6297 -▁effectively -6298 -▁EU -6299 -print -6300 -▁auto -6301 -▁dial -6302 -sembly -6303 -▁Miami -6304 -▁silver -6305 -▁mixture -6306 -▁contemporary -6307 -▁expectations -6308 -▁:) -6309 -abet -6310 -▁Ball -6311 -intage -6312 -▁baking -6313 -▁enthus -6314 -▁unable -6315 -▁carried -6316 -▁circumst -6317 -▁intellig -6318 -▁accessible -6319 -▁challenging -6320 -▁perspective -6321 -▁Ira -6322 -▁Low -6323 -▁Want -6324 -letter -6325 -▁bonus -6326 -▁risks -6327 -▁upper -6328 -quality -6329 -▁nearby -6330 -▁pulled -6331 -▁protein -6332 -▁stunning -6333 -▁candidate -6334 -CT -6335 -PR -6336 -▁af -6337 -iece -6338 -ATION -6339 -▁Phys -6340 -▁Italy -6341 -▁stands -6342 -ev -6343 -aze -6344 -claim -6345 -▁Lind -6346 -ington -6347 -▁Beaut -6348 -▁matters -6349 -▁tonight -6350 -▁significantly -6351 -rowse -6352 -▁Nick -6353 -▁laugh -6354 -▁Proper -6355 -▁excess -6356 -▁garlic -6357 -▁univers -6358 -▁witness -6359 -▁approval -6360 -▁medicine -6361 -▁carefully -6362 -sm -6363 -zy -6364 -▁hur -6365 -▁Shop -6366 -▁chapter -6367 -▁complic -6368 -▁joining -6369 -obs -6370 -flow -6371 -oral -6372 -▁Cir -6373 -oured -6374 -▁fulf -6375 -▁equal -6376 -▁kinds -6377 -▁awarded -6378 -▁bedroom -6379 -▁channel -6380 -▁hosting -6381 -▁guidance -6382 -▁vacation -6383 -▁adventure -6384 -▁increases -6385 -▁recording -6386 -▁availability -6387 -▁SU -6388 -▁Dub -6389 -▁Requ -6390 -▁sole -6391 -▁Never -6392 -▁Works -6393 -▁likes -6394 -▁emphas -6395 -▁festival -6396 -▁accessories -6397 -bal -6398 -zer -6399 -▁glad -6400 -▁iron -6401 -▁tall -6402 -▁Heart -6403 -▁loans -6404 -▁Spanish -6405 -UL -6406 -rete -6407 -▁ease -6408 -riends -6409 -▁filed -6410 -▁renew -6411 -clusion -6412 -▁cooper -6413 -▁Republican -6414 -▁exhibition -6415 -▁partnership -6416 -stal -6417 -▁hopes -6418 -▁Credit -6419 -▁Mobile -6420 -▁SE -6421 -▁Rub -6422 -acked -6423 -ether -6424 -folio -6425 -▁bags -6426 -nesota -6427 -orgeous -6428 -▁creates -6429 -▁speaking -6430 -▁lifestyle -6431 -HA -6432 -sen -6433 -you -6434 -▁diss -6435 -▁hang -6436 -▁vend -6437 -▁Connect -6438 -▁Student -6439 -To -6440 -▁) -6441 -▁AR -6442 -adow -6443 -▁unf -6444 -▁legs -6445 -▁occup -6446 -▁Disney -6447 -▁appeal -6448 -▁assets -6449 -▁motion -6450 -▁trends -6451 -▁clothes -6452 -▁context -6453 -▁reporting -6454 -▁replacement -6455 -FC -6456 -yth -6457 -onto -6458 -yard -6459 -agues -6460 -▁Email -6461 -▁spaces -6462 -▁entirely -6463 -▁scholars -6464 -▁constantly -6465 -!" -6466 -anny -6467 -ican -6468 -long -6469 -▁arms -6470 -orders -6471 -▁shift -6472 -▁stamp -6473 -▁forest -6474 -▁Members -6475 -▁certific -6476 -▁searching -6477 -▁sustainable -6478 -▁OS -6479 -irts -6480 -onym -6481 -rition -6482 -▁spark -6483 -▁Number -6484 -▁Taylor -6485 -▁engage -6486 -▁manner -6487 -▁conflic -6488 -▁believes -6489 -▁submitted -6490 -II -6491 -bi -6492 -▁LED -6493 -comes -6494 -eding -6495 -▁kill -6496 -▁luxury -6497 -▁Studies -6498 -▁streets -6499 -▁procedures -6500 -ml -6501 -▁pil -6502 -▁fort -6503 -▁Still -6504 -▁sudden -6505 -▁outstanding -6506 -rid -6507 -▁Rh -6508 -foot -6509 -▁odd -6510 -▁cuts -6511 -▁Field -6512 -▁goods -6513 -▁negot -6514 -▁awards -6515 -▁criminal -6516 -▁monitoring -6517 -▁originally -6518 -▁SC -6519 -▁Kim -6520 -ially -6521 -▁Russian -6522 -▁invited -6523 -▁trained -6524 -▁Southern -6525 -▁millions -6526 -▁seriously -6527 -▁performing -6528 -▁transition -6529 -erts -6530 -ikes -6531 -▁Pot -6532 -▁eleg -6533 -▁weak -6534 -▁walls -6535 -▁recycl -6536 -▁refund -6537 -▁unlike -6538 -▁Arizona -6539 -▁capture -6540 -osc -6541 -asts -6542 -emic -6543 -izer -6544 -▁Pop -6545 -▁dim -6546 -▁rac -6547 -athan -6548 -ented -6549 -▁ille -6550 -▁zone -6551 -▁factor -6552 -▁prompt -6553 -▁reward -6554 -friendly -6555 -PC -6556 -ih -6557 -pat -6558 -bing -6559 -▁mal -6560 -▁Very -6561 -▁entr -6562 -▁horse -6563 -▁quote -6564 -▁museum -6565 -▁Mountain -6566 -Le -6567 -Ph -6568 -ba -6569 -▁Ra -6570 -▁Far -6571 -▁anx -6572 -▁vul -6573 -▁Jersey -6574 -▁conver -6575 -▁relief -6576 -▁illness -6577 -▁fighting -6578 -ATE -6579 -icket -6580 -▁blow -6581 -▁remov -6582 -▁Despite -6583 -▁Seattle -6584 -▁Standard -6585 -▁interests -6586 -▁foundation -6587 -▁cm -6588 -izza -6589 -front -6590 -▁Braz -6591 -▁Kenn -6592 -▁Pract -6593 -▁Should -6594 -▁herself -6595 -▁virtual -6596 -▁younger -6597 -HS -6598 -born -6599 -elry -6600 -▁tip -6601 -▁Easy -6602 -▁Ford -6603 -▁Iraq -6604 -▁moves -6605 -▁pocket -6606 -▁involve -6607 -▁examples -6608 -ani -6609 -rell -6610 -▁rose -6611 -▁smile -6612 -▁pounds -6613 -▁wealth -6614 -▁offices -6615 -▁flexible -6616 -▁Minnesota -6617 -▁transportation -6618 -▁Fre -6619 -▁Ire -6620 -▁Fall -6621 -▁gifts -6622 -▁input -6623 -▁Senior -6624 -▁upload -6625 -▁bathroom -6626 -▁assessment -6627 -▁capabilities -6628 -▁Jr -6629 -▁Ray -6630 -▁Rod -6631 -▁Stat -6632 -▁eggs -6633 -▁hole -6634 -▁pink -6635 -▁directed -6636 -▁identity -6637 -anes -6638 -ifer -6639 -iler -6640 -uter -6641 -▁Luc -6642 -▁Sav -6643 -▁beer -6644 -▁rein -6645 -▁bottle -6646 -▁Finally -6647 -▁airport -6648 -▁founded -6649 -▁clinical -6650 -▁ultimate -6651 -RS -6652 -sey -6653 -▁Army -6654 -▁debut -6655 -aturally -6656 -▁scientific -6657 -At -6658 -▁Ha -6659 -aron -6660 -▁Ask -6661 -▁Jac -6662 -▁sac -6663 -▁Bible -6664 -▁Royal -6665 -▁worst -6666 -illiant -6667 -▁distinct -6668 -▁improving -6669 -car -6670 -ilst -6671 -quir -6672 -▁Est -6673 -▁Kat -6674 -▁Vers -6675 -▁Event -6676 -▁elimin -6677 -▁figures -6678 -▁fishing -6679 -▁forever -6680 -▁copyright -6681 -da -6682 -▁Put -6683 -▁bab -6684 -ashed -6685 -▁Supp -6686 -▁faces -6687 -▁hospit -6688 -▁Country -6689 -▁Software -6690 -▁? -6691 -▁Non -6692 -ingly -6693 -▁garage -6694 -▁Instagram -6695 -▁tie -6696 -arrow -6697 -icate -6698 -▁Come -6699 -▁Site -6700 -▁Again -6701 -▁spoke -6702 -▁rating -6703 -▁Charles -6704 -▁visited -6705 -▁residential -6706 -▁Cab -6707 -ylvan -6708 -▁Arab -6709 -▁Fact -6710 -▁hasn -6711 -▁blank -6712 -▁stone -6713 -aration -6714 -▁entered -6715 -▁objects -6716 -▁rig -6717 -▁split -6718 -▁contribute -6719 -▁Unfortunately -6720 -RI -6721 -awn -6722 -uine -6723 -▁Bed -6724 -▁Dist -6725 -season -6726 -▁liked -6727 -▁spots -6728 -▁murder -6729 -▁Atlanta -6730 -▁developers -6731 -▁implementation -6732 -eah -6733 -With -6734 -▁coc -6735 -▁san -6736 -▁sky -6737 -▁Term -6738 -▁pitc -6739 -cluded -6740 -▁Radio -6741 -▁shower -6742 -▁Looking -6743 -▁Systems -6744 -▁baseball -6745 -▁calendar -6746 -▁Professor -6747 -▁procedure -6748 -oes -6749 -▁Ms -6750 -That -6751 -▁Save -6752 -▁cups -6753 -▁vital -6754 -resents -6755 -▁Member -6756 -▁linked -6757 -▁historical -6758 -▁possibility -6759 -Se -6760 -omy -6761 -umps -6762 -▁Mom -6763 -▁Foot -6764 -▁vibr -6765 -▁pitch -6766 -▁flavor -6767 -▁liquid -6768 -▁drawing -6769 -▁fitness -6770 -▁password -6771 -▁household -6772 -▁programme -6773 -▁atmosphere -6774 -▁reputation -6775 -andy -6776 -hell -6777 -ossible -6778 -▁enroll -6779 -▁papers -6780 -▁recipes -6781 -▁attached -6782 -▁mountain -6783 -▁organized -6784 -▁LA -6785 -▁Pow -6786 -▁hall -6787 -▁soph -6788 -▁tiss -6789 -asters -6790 -▁liber -6791 -▁Having -6792 -▁critic -6793 -▁muscle -6794 -▁talked -6795 -▁Administration -6796 -LY -6797 -One -6798 -host -6799 -▁Sem -6800 -▁Van -6801 -▁empt -6802 -▁seed -6803 -Americ -6804 -▁Brazil -6805 -▁Russia -6806 -▁carbon -6807 -▁passing -6808 -▁privacy -6809 -▁seasons -6810 -▁victims -6811 -▁frequently -6812 -▁institutions -6813 -.' -6814 -MP -6815 -But -6816 -rad -6817 -▁CO -6818 -▁PA -6819 -▁Space -6820 -▁chose -6821 -▁Living -6822 -▁theory -6823 -▁Shipping -6824 -▁MA -6825 -Read -6826 -▁ads -6827 -enger -6828 -ordan -6829 -▁rail -6830 -▁tech -6831 -▁regul -6832 -▁profit -6833 -▁managing -6834 -▁circumstances -6835 -ras -6836 -adel -6837 -tain -6838 -▁Son -6839 -▁Barb -6840 -▁hurt -6841 -▁proven -6842 -▁Justice -6843 -▁historic -6844 -▁networks -6845 -▁permission -6846 -▁legislation -6847 -▁publication -6848 -phy -6849 -▁Ba -6850 -bury -6851 -▁Cru -6852 -▁Cut -6853 -rible -6854 -▁butt -6855 -▁inch -6856 -▁Image -6857 -▁Express -6858 -▁regulations -6859 -dy -6860 -neys -6861 -ucky -6862 -▁err -6863 -uling -6864 -▁counsel -6865 -ta -6866 -ura -6867 -▁BE -6868 -▁Ur -6869 -olis -6870 -▁Fac -6871 -worth -6872 -▁Prom -6873 -▁skill -6874 -unction -6875 -▁Source -6876 -▁debate -6877 -▁Further -6878 -▁exposure -6879 -ubs -6880 -▁($ -6881 -▁Mir -6882 -▁Nic -6883 -▁Tax -6884 -▁cos -6885 -▁west -6886 -▁Garden -6887 -▁tracks -6888 -▁operate -6889 -RL -6890 -nders -6891 -▁Link -6892 -▁Name -6893 -▁lets -6894 -ffered -6895 -▁breath -6896 -▁qualified -6897 -▁represents -6898 -▁Leg -6899 -▁Oak -6900 -▁Brad -6901 -▁delay -6902 -▁finds -6903 -▁Season -6904 -▁walked -6905 -▁technique -6906 -▁NAS -6907 -▁bow -6908 -▁obl -6909 -▁tou -6910 -▁Anth -6911 -uclear -6912 -▁Choose -6913 -▁saving -6914 -▁authors -6915 -▁Learning -6916 -▁contrast -6917 -ella -6918 -ione -6919 -pons -6920 -▁Ltd -6921 -▁lad -6922 -icial -6923 -▁Scot -6924 -▁Brian -6925 -▁normally -6926 -▁realized -6927 -▁authentic -6928 -zes -6929 -urse -6930 -▁Rog -6931 -eller -6932 -▁fifth -6933 -▁merch -6934 -▁sight -6935 -▁tasks -6936 -▁hosted -6937 -▁reader -6938 -▁causing -6939 -▁savings -6940 -▁downtown -6941 -▁instance -6942 -By -6943 -odd -6944 -▁OR -6945 -▁Tony -6946 -▁mold -6947 -▁casual -6948 -▁execut -6949 -igration -6950 -ographic -6951 -▁anticip -6952 -▁justice -6953 -▁promise -6954 -▁somewhere -6955 -▁Professional -6956 -▁architecture -6957 -ingu -6958 -stra -6959 -entle -6960 -▁coat -6961 -▁smell -6962 -▁templ -6963 -ultural -6964 -▁sample -6965 -▁consequ -6966 -▁portion -6967 -▁estimated -6968 -Sc -6969 -idi -6970 -▁Pict -6971 -▁trib -6972 -remony -6973 -▁Labor -6974 -▁agric -6975 -▁trick -6976 -▁coordin -6977 -▁default -6978 -▁sending -6979 -▁upgrade -6980 -▁priority -6981 -▁interpret -6982 -▁surprising -6983 -▁volunteers -6984 -ults -6985 -cknow -6986 -▁batt -6987 -▁soil -6988 -▁mainly -6989 -▁manual -6990 -▁matches -6991 -▁gorgeous -6992 -▁shoulder -6993 -▁certified -6994 -▁apparently -6995 -▁continuing -6996 -▁situations -6997 -law -6998 -▁Es -6999 -▁exec -7000 -▁warn -7001 -arters -7002 -▁Stock -7003 -▁banks -7004 -▁bench -7005 -▁facil -7006 -▁lucky -7007 -ylvania -7008 -▁Golden -7009 -▁planet -7010 -▁posting -7011 -▁immediate -7012 -▁guidelines -7013 -bel -7014 -▁PH -7015 -star -7016 -▁Buy -7017 -▁Hou -7018 -words -7019 -▁Wilson -7020 -▁blocks -7021 -▁Financial -7022 -▁discussed -7023 -owa -7024 -ulf -7025 -ulpt -7026 -▁Mix -7027 -▁Mrs -7028 -▁USB -7029 -class -7030 -▁bear -7031 -▁hate -7032 -earing -7033 -▁firms -7034 -▁shops -7035 -▁Policy -7036 -▁Spirit -7037 -▁drinks -7038 -▁scheme -7039 -▁Customer -7040 -▁Medicine -7041 -▁Lar -7042 -anned -7043 -▁fasc -7044 -ealand -7045 -▁charm -7046 -ogether -7047 -respond -7048 -▁ending -7049 -▁terror -7050 -▁attacks -7051 -▁singles -7052 -▁workshop -7053 -▁Engineering -7054 -▁FA -7055 -iger -7056 -▁Ron -7057 -uster -7058 -▁Stay -7059 -▁magn -7060 -▁Sales -7061 -▁layer -7062 -▁prove -7063 -▁teasp -7064 -▁fairly -7065 -▁vulner -7066 -▁Ireland -7067 -▁external -7068 -nam -7069 -▁Yet -7070 -▁hat -7071 -▁vice -7072 -ingers -7073 -▁aspect -7074 -▁capable -7075 -▁Catholic -7076 -▁retirement -7077 -from -7078 -icit -7079 -unes -7080 -▁Cro -7081 -inder -7082 -▁scan -7083 -bridge -7084 -▁Motor -7085 -▁Order -7086 -▁Phone -7087 -▁stuck -7088 -eration -7089 -▁loving -7090 -▁Toronto -7091 -▁closely -7092 -▁injured -7093 -▁listing -7094 -▁Memorial -7095 -▁clicking -7096 -▁programming -7097 -aping -7098 -▁bare -7099 -▁Linux -7100 -▁climb -7101 -▁saved -7102 -▁orange -7103 -▁Zealand -7104 -▁proceed -7105 -▁believed -7106 -▁listening -7107 -▁industries -7108 -▁destination -7109 -▁Cy -7110 -▁EV -7111 -rich -7112 -▁Exp -7113 -▁wra -7114 -uting -7115 -▁Conf -7116 -▁Eric -7117 -▁juice -7118 -▁casino -7119 -▁breaking -7120 -▁memories -7121 -▁collected -7122 -▁landscape -7123 -SE -7124 -lo -7125 -▁Ca -7126 -▁FL -7127 -alle -7128 -aska -7129 -▁Ram -7130 -otted -7131 -▁Band -7132 -▁Tenn -7133 -▁terr -7134 -angers -7135 -▁reform -7136 -▁strike -7137 -▁Welcome -7138 -▁doctors -7139 -▁Material -7140 -▁enjoying -7141 -▁religious -7142 -▁spiritual -7143 -▁suggested -7144 -ati -7145 -▁MD -7146 -▁OK -7147 -Tube -7148 -aste -7149 -odge -7150 -▁hell -7151 -▁Roman -7152 -▁blend -7153 -▁forth -7154 -▁meets -7155 -▁assign -7156 -▁winners -7157 -▁machines -7158 -▁alongside -7159 -▁relatively -7160 -equ -7161 -ghan -7162 -▁Fox -7163 -▁Ide -7164 -oster -7165 -cludes -7166 -▁index -7167 -faction -7168 -▁riding -7169 -▁choosing -7170 -▁pleasure -7171 -▁strategic -7172 -▁anniversary -7173 -Ad -7174 -gypt -7175 -▁Dur -7176 -▁gym -7177 -child -7178 -imize -7179 -▁Line -7180 -▁yard -7181 -▁Smart -7182 -▁Think -7183 -▁aside -7184 -▁boxes -7185 -▁newly -7186 -▁prize -7187 -▁treatments -7188 -▁celebration -7189 -▁Subsc -7190 -▁bodies -7191 -▁writers -7192 -▁requests -7193 -▁designers -7194 -▁engagement -7195 -bro -7196 -inte -7197 -amber -7198 -▁Dave -7199 -▁east -7200 -▁Davis -7201 -▁Happy -7202 -▁bunch -7203 -▁pharm -7204 -▁belief -7205 -▁covering -7206 -▁extension -7207 -▁performances -7208 -▁WW -7209 -days -7210 -▁Sky -7211 -▁arg -7212 -▁Bang -7213 -▁elev -7214 -▁Camer -7215 -▁buyers -7216 -▁Meanwhile -7217 -▁brilliant -7218 -De -7219 -ls -7220 -agon -7221 -obby -7222 -▁Dar -7223 -▁NFL -7224 -▁Sep -7225 -ormal -7226 -▁enem -7227 -ensity -7228 -giving -7229 -▁birds -7230 -▁broke -7231 -▁giant -7232 -▁proof -7233 -▁franch -7234 -▁division -7235 -nic -7236 -inos -7237 -▁Pak -7238 -ashes -7239 -osophy -7240 -▁Asian -7241 -▁Kevin -7242 -lements -7243 -▁acknow -7244 -▁symbol -7245 -▁titles -7246 -sylvania -7247 -▁packaging -7248 -▁platforms -7249 -▁instrument -7250 -▁differences -7251 -oty -7252 -▁raw -7253 -▁unw -7254 -iders -7255 -ureau -7256 -▁Adam -7257 -▁iPad -7258 -esides -7259 -▁meals -7260 -▁river -7261 -▁compat -7262 -▁enables -7263 -▁drinking -7264 -▁volunteer -7265 -’. -7266 -▁PDF -7267 -inton -7268 -▁mile -7269 -▁slic -7270 -▁solo -7271 -▁superv -7272 -▁letters -7273 -▁authority -7274 -.’ -7275 -wan -7276 -▁PL -7277 -alse -7278 -rage -7279 -wart -7280 -▁pip -7281 -▁Bush -7282 -▁Iran -7283 -lisher -7284 -parent -7285 -▁Story -7286 -▁urban -7287 -ainless -7288 -▁consistent -7289 -pes -7290 -▁Uk -7291 -▁|| -7292 -bles -7293 -wich -7294 -▁kit -7295 -ronics -7296 -▁Chall -7297 -▁Model -7298 -▁centers -7299 -▁charity -7300 -▁typical -7301 -▁explains -7302 -▁replaced -7303 -▁newspaper -7304 -▁communications -7305 -GA -7306 -OVID -7307 -▁rug -7308 -▁acts -7309 -▁lapt -7310 -▁vacc -7311 -▁vast -7312 -ateful -7313 -jection -7314 -▁infect -7315 -▁YouTube -7316 -▁mortgage -7317 -▁CN -7318 -leep -7319 -oker -7320 -▁Jay -7321 -▁stim -7322 -▁tape -7323 -▁trim -7324 -▁tooth -7325 -▁dreams -7326 -▁falling -7327 -▁handling -7328 -▁holidays -7329 -▁swimming -7330 -cons -7331 -iley -7332 -page -7333 -▁stir -7334 -▁Return -7335 -▁decade -7336 -▁domain -7337 -▁singer -7338 -▁Perhaps -7339 -▁destroy -7340 -▁dynamic -7341 -▁lighting -7342 -▁proposal -7343 -▁categories -7344 -▁encouraged -7345 -▁membership -7346 -▁personally -7347 -Fi -7348 -acious -7349 -▁Jason -7350 -▁Jordan -7351 -▁Columbia -7352 -▁forecast -7353 -▁informed -7354 -▁wireless -7355 -▁classroom -7356 -▁accomplish -7357 -▁initiative -7358 -▁suggestions -7359 -▁Po -7360 -▁mut -7361 -erman -7362 -▁Bird -7363 -▁Mill -7364 -▁Swed -7365 -▁slee -7366 -▁susp -7367 -▁Egypt -7368 -▁Staff -7369 -▁Treat -7370 -▁recre -7371 -▁solve -7372 -▁agents -7373 -▁combine -7374 -▁founder -7375 -▁percentage -7376 -▁Advis -7377 -▁Cancer -7378 -▁arrive -7379 -▁headed -7380 -▁expansion -7381 -▁sensitive -7382 -▁manufacturers -7383 -TER -7384 -uis -7385 -athy -7386 -▁Bad -7387 -▁Ess -7388 -▁magic -7389 -▁penal -7390 -▁Agency -7391 -▁Miller -7392 -▁Gallery -7393 -ounce -7394 -▁bars -7395 -▁embr -7396 -▁tied -7397 -▁Being -7398 -▁crash -7399 -▁flash -7400 -▁filter -7401 -▁Classic -7402 -▁Houston -7403 -▁shouldn -7404 -▁Remember -7405 -▁Transport -7406 -▁participating -7407 -▁ast -7408 -▁Talk -7409 -▁dust -7410 -▁Annual -7411 -▁Recent -7412 -▁slowly -7413 -▁Airport -7414 -▁Kingdom -7415 -▁pricing -7416 -▁travell -7417 -▁Northern -7418 -▁enterprise -7419 -ko -7420 -▁Josh -7421 -▁evol -7422 -▁mood -7423 -▁unus -7424 -▁facts -7425 -▁phones -7426 -▁Consult -7427 -▁ancient -7428 -▁presents -7429 -▁printing -7430 -▁Secretary -7431 -▁permanent -7432 -wis -7433 -onna -7434 -level -7435 -▁hire -7436 -amsung -7437 -rovers -7438 -▁Brook -7439 -▁venue -7440 -▁Joseph -7441 -▁gender -7442 -▁extract -7443 -▁intense -7444 -ervations -7445 -▁Pennsylvania -7446 -▁DI -7447 -..... -7448 -abeth -7449 -▁Base -7450 -▁assum -7451 -▁dealing -7452 -▁gallery -7453 -▁genuine -7454 -▁portfolio -7455 -▁enforcement -7456 -FA -7457 -esy -7458 -site -7459 -▁suc -7460 -igate -7461 -uties -7462 -▁Film -7463 -▁gall -7464 -ership -7465 -▁Level -7466 -▁roles -7467 -ologist -7468 -▁Create -7469 -▁watched -7470 -▁producing -7471 -▁IC -7472 -lers -7473 -wear -7474 -▁Dam -7475 -asted -7476 -mates -7477 -▁fest -7478 -making -7479 -▁scenes -7480 -▁constit -7481 -▁carrying -7482 -▁suffered -7483 -▁traveling -7484 -▁attractive -7485 -OD -7486 -Tr -7487 -▁Own -7488 -▁Sea -7489 -iking -7490 -oices -7491 -▁Webs -7492 -▁vari -7493 -ardens -7494 -▁Grant -7495 -ulating -7496 -▁Silver -7497 -▁border -7498 -▁assault -7499 -▁Continue -7500 -▁generate -7501 -▁assistant -7502 -▁Collection -7503 -▁guaranteed -7504 -▁recommendations -7505 -Do -7506 -axy -7507 -bar -7508 -pir -7509 -Book -7510 -▁Sym -7511 -▁Stan -7512 -▁trig -7513 -▁wins -7514 -▁Books -7515 -▁absor -7516 -▁stake -7517 -▁Studio -7518 -▁Quality -7519 -▁chances -7520 -▁Personal -7521 -▁equipped -7522 -▁Ter -7523 -Press -7524 -books -7525 -active -7526 -▁grass -7527 -▁opens -7528 -▁solar -7529 -inating -7530 -▁compens -7531 -▁heading -7532 -▁Everyone -7533 -▁diseases -7534 -▁reducing -7535 -▁Hollywood -7536 -▁languages -7537 -▁professor -7538 -▁incredibly -7539 -boy -7540 -▁rh -7541 -aine -7542 -ilty -7543 -raid -7544 -burgh -7545 -▁Fred -7546 -▁actor -7547 -▁formed -7548 -▁Eastern -7549 -▁booking -7550 -▁podcast -7551 -▁speaker -7552 -▁Experience -7553 -▁interactive -7554 -SC -7555 -Te -7556 -rm -7557 -amel -7558 -▁hel -7559 -▁anyway -7560 -▁lawyer -7561 -▁neighb -7562 -▁cookies -7563 -▁Magazine -7564 -▁Therefore -7565 -acc -7566 -ila -7567 -▁CL -7568 -▁Deb -7569 -asant -7570 -ctive -7571 -▁Bern -7572 -▁lect -7573 -▁Force -7574 -▁Henry -7575 -▁Would -7576 -▁formal -7577 -▁string -7578 -▁filling -7579 -▁Products -7580 -▁purchasing -7581 -▁connections -7582 -alo -7583 -run -7584 -▁Gi -7585 -etch -7586 -game -7587 -phia -7588 -shire -7589 -▁narr -7590 -▁alive -7591 -▁pride -7592 -graduate -7593 -▁preferred -7594 -▁Hi -7595 -ials -7596 -▁Ath -7597 -▁Hun -7598 -▁Mov -7599 -stein -7600 -▁Clin -7601 -▁Emer -7602 -▁Guard -7603 -▁Major -7604 -▁phase -7605 -▁limits -7606 -▁marked -7607 -▁writes -7608 -▁defined -7609 -▁deposit -7610 -▁visible -7611 -▁suggests -7612 -oto -7613 -swe -7614 -roke -7615 -▁Tel -7616 -▁Kids -7617 -▁seats -7618 -▁shell -7619 -▁accused -7620 -▁aggress -7621 -▁expressed -7622 -▁basketball -7623 -Fr -7624 -▁EN -7625 -onic -7626 -allas -7627 -▁bact -7628 -lessly -7629 -▁empty -7630 -▁Estate -7631 -▁hotels -7632 -▁nights -7633 -▁racing -7634 -▁Comment -7635 -▁jewelry -7636 -▁substant -7637 -▁primarily -7638 -esh -7639 -imp -7640 -▁CP -7641 -bell -7642 -▁bid -7643 -▁gay -7644 -utter -7645 -▁Past -7646 -▁aims -7647 -▁lady -7648 -▁habit -7649 -▁Father -7650 -▁Histor -7651 -▁Mother -7652 -▁Things -7653 -▁rental -7654 -▁shapes -7655 -▁weapons -7656 -itionally -7657 -▁accuracy -7658 -▁resulting -7659 -▁creativity -7660 -▁specialist -7661 -▁vegetables -7662 -AV -7663 -▁oz -7664 -ogue -7665 -▁Has -7666 -▁lie -7667 -ifies -7668 -inity -7669 -▁cycl -7670 -intend -7671 -▁Based -7672 -▁bills -7673 -limited -7674 -▁remark -7675 -▁rising -7676 -▁engaged -7677 -▁instant -7678 -▁organis -7679 -▁politics -7680 -▁Published -7681 -▁recognition -7682 -ns -7683 -hour -7684 -▁Las -7685 -inois -7686 -uters -7687 -▁Give -7688 -▁Iowa -7689 -▁Marc -7690 -▁Tele -7691 -abetes -7692 -▁Vegas -7693 -▁criteria -7694 -▁suffering -7695 -▁compliance -7696 -essee -7697 -▁rice -7698 -▁marks -7699 -adelphia -7700 -▁Officer -7701 -▁compare -7702 -▁desired -7703 -▁component -7704 -▁highlights -7705 -▁TR -7706 -uana -7707 -▁tub -7708 -oween -7709 -▁dism -7710 -▁Prime -7711 -▁brush -7712 -▁Kansas -7713 -▁dollar -7714 -▁Britain -7715 -▁crucial -7716 -▁graphic -7717 -▁recover -7718 -▁achieved -7719 -▁literally -7720 -▁interviews -7721 -jo -7722 -igs -7723 -lee -7724 -▁Ap -7725 -greg -7726 -▁Map -7727 -▁tap -7728 -▁Fast -7729 -▁HERE -7730 -▁duty -7731 -makers -7732 -▁Among -7733 -▁Steel -7734 -▁knock -7735 -▁healing -7736 -▁illegal -7737 -▁admitted -7738 -▁describe -7739 -▁entering -7740 -▁releases -7741 -▁speakers -7742 -▁Solutions -7743 -▁functional -7744 -des -7745 -▁pra -7746 -▁Roll -7747 -▁Cover -7748 -▁Kelly -7749 -athered -7750 -▁intent -7751 -▁Edition -7752 -▁massage -7753 -▁packages -7754 -▁Following -7755 -▁attending -7756 -▁obviously -7757 -li -7758 -uan -7759 -▁EX -7760 -mers -7761 -▁Meth -7762 -▁keys -7763 -▁heads -7764 -holders -7765 -▁Change -7766 -▁Orange -7767 -▁matching -7768 -▁displayed -7769 -▁recognize -7770 -▁wondering -7771 -▁correspond -7772 -isa -7773 -▁CC -7774 -▁IM -7775 -Cont -7776 -orous -7777 -▁Diego -7778 -▁dough -7779 -▁trips -7780 -▁signal -7781 -▁developer -7782 -▁exceptional -7783 -▁increasingly -7784 -%. -7785 -ja -7786 -htt -7787 -▁Ros -7788 -athon -7789 -heast -7790 -▁Dead -7791 -▁puts -7792 -▁till -7793 -▁Nation -7794 -▁alumin -7795 -▁struck -7796 -novation -7797 -▁claimed -7798 -▁farmers -7799 -▁hitting -7800 -▁whenever -7801 -▁officially -7802 -▁introduction -7803 -pson -7804 -▁Isl -7805 -found -7806 -▁Auto -7807 -▁Body -7808 -▁king -7809 -▁mand -7810 -inding -7811 -▁Table -7812 -▁Forest -7813 -▁Valent -7814 -▁narrow -7815 -▁colours -7816 -▁Attorney -7817 -▁networking -7818 -▁necessarily -7819 -▁improvements -7820 -tail -7821 -▁bug -7822 -▁clar -7823 -▁Civil -7824 -utional -7825 -▁hidden -7826 -▁Theatre -7827 -▁texture -7828 -▁checking -7829 -▁constant -7830 -▁licensed -7831 -▁Cry -7832 -▁cust -7833 -▁root -7834 -ickets -7835 -terior -7836 -▁Youth -7837 -▁loose -7838 -▁setup -7839 -▁acting -7840 -▁Chapter -7841 -▁Reading -7842 -▁occurred -7843 -▁struggling -7844 -TP -7845 -tw -7846 -AND -7847 -▁ -7848 -e -7849 -t -7850 -a -7851 -o -7852 -i -7853 -n -7854 -s -7855 -r -7856 -h -7857 -l -7858 -d -7859 -c -7860 -u -7861 -m -7862 -p -7863 -g -7864 -f -7865 -y -7866 -w -7867 -b -7868 -. -7869 -v -7870 -, -7871 -k -7872 -T -7873 -I -7874 -S -7875 -A -7876 -- -7877 -C -7878 -0 -7879 -1 -7880 -M -7881 -P -7882 -B -7883 -x -7884 -2 -7885 -W -7886 -D -7887 -R -7888 -E -7889 -H -7890 -F -7891 -L -7892 -O -7893 -N -7894 -’ -7895 -' -7896 -: -7897 -G -7898 -j -7899 -) -7900 -3 -7901 -( -7902 -z -7903 -5 -7904 -q -7905 -" -7906 -U -7907 -4 -7908 -J -7909 -9 -7910 -6 -7911 -8 -7912 -V -7913 -Y -7914 -K -7915 -7 -7916 -! -7917 -| -7918 -/ -7919 -? -7920 -“ -7921 -” -7922 -; -7923 -– -7924 -& -7925 -$ -7926 -— -7927 -Q -7928 -X -7929 -% -7930 -Z -7931 diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/requirements.txt b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/requirements.txt deleted file mode 100644 index 0c5eedce7b..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -numpy -tqdm -torch==2.10 -huggingface-hub -kernels -setuptools -typing-extensions==4.15.0 -datasets -tiktoken -sentencepiece \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/run_cuda_ternary.sh b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/run_cuda_ternary.sh deleted file mode 100644 index a79ce9476e..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/run_cuda_ternary.sh +++ /dev/null @@ -1,69 +0,0 @@ -RUN_ID=pushing_run_ternary_4 \ -DATA_PATH=./data/datasets/fineweb10B_sp8192 \ -TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model \ -ATTN_PROJ_TYPE=standard \ -LOGIT_HEAD_TYPE=standard \ -TVERSKY_MEMBERSHIP=sigmoid \ -TVERSKY_NUM_FEATURES=0 \ -TVERSKY_FEATURE_POOLS=0 \ -VOCAB_SIZE=8192 \ -BITNET_GROUP_SIZE=128 \ -BIGRAM_HASH=0 \ -EMBED_DIM=254 \ -TRAINING_DEPTH_RECURRENCE=0 \ -EVAL_DEPTH_RECURRENCE=0 \ -NUM_LAYERS=10 \ -MODEL_DIM=768 \ -NUM_KV_HEADS=4 \ -NUM_HEADS=8 \ -DIFF_ATTN=0 \ -MLP_MULT=4 \ -MLP_GROUPS=0 \ -MATRIX_OPTIMIZER=muon \ -ADAM_LR=0.05 \ -ADAM_WD=0.05 \ -MUON_BACKEND_STEPS=3 \ -MUON_MOMENTUM=0.95 \ -MUON_MOMENTUM_WARMUP_START=0.85 \ -MUON_MOMENTUM_WARMUP_STEPS=500 \ -MUON_WD=0.0 \ -MATRIX_LR=0.04 \ -SCALAR_LR=0.02 \ -TIED_EMBED_LR=0.02 \ -WARMDOWN_FRACTION=0.2 \ -LOGIT_SOFTCAP=10 \ -QK_GAIN_INIT=2.25 \ -ROPE_TYPE=yarn \ -YARN_MAX_LEN=2048 \ -ROPE_BASE=5000 \ -BATCH_TOKENS_START=0 \ -BATCH_SCHEDULE_FRACTION=0.33 \ -TRAIN_BATCH_TOKENS=524288 \ -SEQ_LEN_START=0 \ -SEQ_SCHEDULE_FRACTION=0.0 \ -TRAIN_SEQ_LEN=1024 \ -SMEAR=0 \ -ITERATIONS=10000 \ -WARMUP_STEPS=5 \ -MAX_WALLCLOCK_SECONDS=599 \ -VAL_LOSS_EVERY=0 \ -TRAIN_LOG_EVERY=1000 \ -CHURN_LOG_EVERY=0 \ -VAL_MAX_TOKENS=0 \ -TIE_EMBEDDINGS=1 \ -UNTIE_AT_FRACTION=0.00 \ -HEAD_LR=0.02 \ -CORR_WEIGHT_LR=0.02 \ -ACTIVATION=relu2 \ -SOFTCAP_TYPE=poly \ -MTP_HEADS=0 \ -REFINER=0 \ -REFINER_KERNEL=3 \ -SLIDING_EVAL=1 \ -SLIDING_EVAL_STRIDE=16 \ -SLIDING_BATCH_SIZE=256 \ -TEMP_SCALING=1 \ -FP_STORAGE=FP8 \ -SEED=42 \ -COMPILE_MODE=default \ -OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 train_gpt_cuda_ternary.py \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/setup.sh b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/setup.sh deleted file mode 100644 index 93f1c41fea..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/setup.sh +++ /dev/null @@ -1,143 +0,0 @@ -#!/bin/bash -# ------------------------------------------------------------------------------- -# Parameter Golf -- Complete Environment Setup Script -# Drop this into the project root and run: bash setup.sh -# ------------------------------------------------------------------------------- - -set -e - -echo "----------------------------------------------" -echo " Parameter Golf -- Environment Setup" -echo "----------------------------------------------" - -# ------------------------------------------------------------------------------- -# 1. Miniconda -# ------------------------------------------------------------------------------- -echo "" -echo "[1/5] Miniconda..." - -if [ -d "$HOME/miniconda3" ]; then - echo " Already installed -- skipping." -else - wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh - bash /tmp/miniconda.sh -b - rm /tmp/miniconda.sh - ~/miniconda3/bin/conda init bash - echo " Installed." -fi - -export PATH="$HOME/miniconda3/bin:$PATH" -source ~/miniconda3/etc/profile.d/conda.sh - -echo " Accepting conda TOS..." -~/miniconda3/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main -~/miniconda3/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r -echo " TOS accepted." - -# ------------------------------------------------------------------------------- -# 2. Python Environment -# ------------------------------------------------------------------------------- -echo "" -echo "[2/5] Python 3.13 environment..." - -if conda env list | grep -q "^golf "; then - echo " Environment 'golf' already exists -- skipping." -else - conda create -n golf python=3.13 -y - echo " Created." -fi - -conda activate golf -echo " Activated." - -# ------------------------------------------------------------------------------- -# 3. Requirements -# ------------------------------------------------------------------------------- -echo "" -echo "[3/5] Requirements..." - -if python3 -c "import torch, sentencepiece, numpy" 2>/dev/null; then - echo " Core packages already installed -- skipping." -else - pip install --upgrade pip -q - pip install -r requirements.txt -q - echo " Installed." -fi - -# ------------------------------------------------------------------------------- -# 4. FlashAttention-3 -# ------------------------------------------------------------------------------- -echo "" -echo "[4/5] FlashAttention-3..." - -if python3 -c "import flash_attn" 2>/dev/null || python3 -c "import flash_attn_interface" 2>/dev/null; then - echo " Already installed -- skipping." -else - # abi3 wheel -- Python 3.9+ compatible, installs in seconds, no compilation - pip install --no-cache-dir "https://download.pytorch.org/whl/cu128/flash_attn_3-3.0.0-cp39-abi3-manylinux_2_28_x86_64.whl" - echo " Installed." -fi - -# ------------------------------------------------------------------------------- -# 5. Dataset -# ------------------------------------------------------------------------------- -echo "" -echo "[5/5] FineWeb dataset (sp8192, 10 shards)..." - -echo " Downloading... ($TRAIN_COUNT/10 train shards found)" -hf download sproos/parameter-golf-tokenizers --include "datasets/fineweb10B_sp8192/*" --local-dir ./data -echo " Downloaded." - -# ------------------------------------------------------------------------------- -# Verification -# ------------------------------------------------------------------------------- -echo "" -echo "----------------------------------------------" -echo " Verification" -echo "----------------------------------------------" - -python3 - << 'EOF' -import sys -import torch -import numpy as np -import glob - -print(f"Python : {sys.version.split()[0]}") -print(f"PyTorch : {torch.__version__}") -print(f"CUDA : {torch.cuda.is_available()}") -print(f"GPUs : {torch.cuda.device_count()}") - -if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - props = torch.cuda.get_device_properties(i) - print(f" GPU {i} : {props.name} ({props.total_memory // 1024**3}GB)") - -try: - import flash_attn - print(f"FlashAttn : {flash_attn.__version__}") -except ImportError: - try: - import flash_attn_interface - print(f"FlashAttn3 : available") - except ImportError: - print(f"FlashAttn : NOT found") - -train_files = sorted(glob.glob("./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin")) -val_files = sorted(glob.glob("./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin")) -print(f"Train shards : {len(train_files)}") -print(f"Val shards : {len(val_files)}") - -if val_files: - total = sum( - int(np.fromfile(f, dtype=' Tensor: - v = np.frombuffer(data, dtype=np.uint8).astype(np.int16) - t = np.zeros((len(v), 5), dtype=np.int8) - for i in range(5): t[:,i] = v % 3; v //= 3 - return torch.from_numpy(t.reshape(-1)[:n].astype(np.int8) - 1) - -def pack_ternary_bitmask(q: Tensor): - f = q.reshape(-1).to(torch.int8).numpy(); n = len(f) - nz = (f != 0) - return np.packbits(nz).tobytes() + np.packbits(f[nz] > 0).tobytes(), n - -def unpack_ternary_bitmask(data: bytes, n: int) -> Tensor: - ms = (n + 7) // 8 - nz = np.unpackbits(np.frombuffer(data[:ms], dtype=np.uint8))[:n].astype(bool) - s = np.unpackbits(np.frombuffer(data[ms:], dtype=np.uint8))[:int(nz.sum())].astype(bool) - w = np.zeros(n, dtype=np.int8); w[nz] = np.where(s, 1, -1) - return torch.from_numpy(w) - -# --------------------------------------------------------------------------- -# FP4 quantization (per-row absmax, 2 values packed per byte) -# --------------------------------------------------------------------------- -def quantize_to_int4(t: Tensor) -> tuple[Tensor, Tensor, list]: - t32 = t.float() - orig_shape = t32.shape - if t32.ndim < 2: - t32 = t32.unsqueeze(0) - absmax = t32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(t32 / scale), -7, 7).to(torch.int8) - flat = q.reshape(-1) - if flat.numel() % 2 != 0: - flat = F.pad(flat, (0, 1)) - low = (flat[0::2] + 8).to(torch.uint8) - high = (flat[1::2] + 8).to(torch.uint8) - return low | (high << 4), scale.half().squeeze(-1), list(orig_shape) - -def dequantize_from_int4(packed: Tensor, scale: Tensor, shape: list) -> Tensor: - low = (packed & 0x0F).to(torch.int8) - 8 - high = ((packed >> 4) & 0x0F).to(torch.int8) - 8 - flat = torch.zeros(packed.numel() * 2, dtype=torch.int8) - flat[0::2] = low - flat[1::2] = high - numel = 1 - for s in shape: - numel *= s - flat = flat[:numel].float() - if len(shape) <= 1: - return (flat * scale.float().squeeze()).reshape(shape) - return (flat.reshape(-1, shape[-1]) * scale.float().unsqueeze(-1)).reshape(shape) - -# --------------------------------------------------------------------------- -# State dict serialization (ternary + fp16/fp8/fp4) -# --------------------------------------------------------------------------- -def q_sd(state_dict: dict, group_size: int = 64, fp_storage=False, ternary_method="standard", ternary_override_names: set | None = None) -> tuple[dict, dict]: - "Ternary for large 2D weight matrices, fp16/fp8/fp4 for everything else." - quantized = {} - stats = {"ternary_params": 0, "ternary_bytes": 0, "fp_params": 0, "fp_bytes": 0} - for name, tensor in state_dict.items(): - if "mtp_heads" in name: - continue - t = tensor.detach().cpu().float().contiguous() - t_orig_shape = list(t.shape) - if t.ndim == 3: - t = t.reshape(t.shape[0], -1) - is_ternary_candidate = ( - t.ndim == 2 and t.numel() > 65_536 - and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name and "bigram_emb" not in name and "lm_head_correction" not in name and "lm_head_U" not in name and "lm_head_V" not in name - and "prototypes" not in name and "tversky" not in name - ) or (ternary_override_names is not None and name in ternary_override_names) - if is_ternary_candidate: - pad = (group_size - t.shape[1] % group_size) % group_size - t_padded = F.pad(t, (0, pad)) if pad > 0 else t - t_grouped = t_padded.reshape(-1, group_size) - scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (t_grouped / scale).round().clamp(-1, 1).to(torch.int8) - - if ternary_method == "standard": - packed_bytes, n_trits = pack_ternary(q) - entry_type = "ternary" - else: - packed_bytes, n_trits = pack_ternary_bitmask(q) - entry_type = "ternary_bitmask" - - quantized[name] = { - "type": entry_type, "packed": packed_bytes, - "scale": scale.half().squeeze(-1), - "shape": list(t.shape), "padded_cols": t_padded.shape[1], - "group_size": group_size, "n_trits": n_trits, - "orig_shape": t_orig_shape, - } - stats["ternary_params"] += t.numel() - stats["ternary_bytes"] += len(packed_bytes) + scale.numel() * 2 - elif fp_storage == "fp4" and t.ndim == 2: - packed, scale, orig_shape = quantize_to_int4(t) - quantized[name] = {"type": "fp4", "packed": packed, "scale": scale, "shape": orig_shape} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += packed.numel() + scale.numel() * 2 - elif fp_storage and t.ndim == 2: - quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() - else: - quantized[name] = {"type": "fp16", "data": t.half()} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() * 2 - return quantized, stats - -def deq_sd(quantized: dict, target_dtype=torch.bfloat16): - "Reconstruct full-precision state dict from quantized representation." - out = {} - for name, entry in quantized.items(): - if entry["type"] in ("ternary", "ternary_bitmask"): - if entry["type"] == "ternary": - q = unpack_ternary(entry["packed"], entry["n_trits"]) - else: - q = unpack_ternary_bitmask(entry["packed"], entry["n_trits"]) - - q = q.float().reshape(-1, entry["group_size"]) - scale = entry["scale"].float().unsqueeze(-1) - q_absmean = q.abs().mean(-1, keepdim=True).clamp(min=1e-8) - t = (q * (scale / q_absmean)).reshape(-1, entry["padded_cols"]) - shape = entry["shape"] - result = t[:shape[0], :shape[1]].to(target_dtype) - orig = entry.get("orig_shape") - out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() - elif entry["type"] == "fp8": - out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() - elif entry["type"] == "fp4": - out[name] = dequantize_from_int4(entry["packed"], entry["scale"], entry["shape"]).to(target_dtype).contiguous() - else: - out[name] = entry["data"].to(target_dtype).contiguous() - return out - -# --------------------------------------------------------------------------- -# Ternary diagnostics (logged during training) -# --------------------------------------------------------------------------- -def tern_stats(model: nn.Module, group_size: int = 64): - total = zeros = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (w / scale).round().clamp(-1, 1) - zeros += int((q == 0).sum().item()) - total += int(q.numel()) - return {"zero_frac": zeros / max(total, 1), "total_weights": total} - -_prev_committed: dict = {} - -def churn_fn(model: nn.Module, group_size: int = 64): - global _prev_committed - total = flipped = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (w / scale).round().clamp(-1, 1).cpu().numpy() - if name in _prev_committed: - flipped += int(np.sum(q != _prev_committed[name])) - total += q.size - _prev_committed[name] = q - return flipped / max(total, 1) - -# --------------------------------------------------------------------------- -# Muon optimizer (Newton-Schulz orthogonalized momentum) -# --------------------------------------------------------------------------- -def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, wd: float = 0.0): - super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr, momentum = group["lr"], group["momentum"] - backend_steps, nesterov = group["backend_steps"], group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() - g = ns_orth(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr:curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - wd = group.get("wd", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.mul_(1 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - -# --------------------------------------------------------------------------- -# Data loading -# --------------------------------------------------------------------------- -def ld_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" Tensor: - chunks = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos:self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank, self.world_size, self.device = rank, world_size, device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x, y - -# --------------------------------------------------------------------------- -# Model -# --------------------------------------------------------------------------- -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - -def apply_qat_ste(w: Tensor, fp_storage: str | bool) -> Tensor: - """Applies Straight-Through Estimator (STE) for FP4 or FP8 simulated quantization.""" - if not fp_storage: - return w - if fp_storage == "fp4": - absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(w / scale), -7.0, 7.0) - w_sim = q * scale - return (w_sim - w).detach() + w - elif fp_storage is True or fp_storage == "fp8": - w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) - return (w_sim - w).detach() + w - return w - -class QATLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = False, fp_storage: str | bool = False): - super().__init__(in_features, out_features, bias=bias) - self.fp_storage = fp_storage - - def forward(self, x: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.linear(x, w_qat.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) - -class QATEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, fp_storage: str | bool = False): - super().__init__(num_embeddings, embedding_dim) - self.fp_storage = fp_storage - - def forward(self, input: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.embedding(input, w_qat, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) - -class TernaryLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=False, group_size=64): - super().__init__(in_features, out_features, bias=bias) - self.group_size = group_size - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_g / scale).round().clamp(-1, 1) - w_ternary = w + ((q * scale).reshape(w.shape) - w).detach() - return F.linear(x, w_ternary, - self.bias.to(x.dtype) if self.bias is not None else None) - - -class NormedTernaryLinear(TernaryLinear): - "Ternary linear with RMSNorm on input — for output projections receiving un-normalized activations." - def forward(self, x: Tensor) -> Tensor: - return super().forward(F.rms_norm(x, (x.size(-1),))) - -class GroupedTernaryLinear(nn.Module): - "Grouped linear with ternary STE. Weight stored as 2D [groups*group_out, group_in] for ternary quantization compatibility." - def __init__(self, in_features, out_features, groups=4, group_size=64, normed=False): - super().__init__() - assert in_features % groups == 0 and out_features % groups == 0 - self.groups = groups - self.group_in = in_features // groups - self.group_out = out_features // groups - self.group_size = group_size - self.normed = normed - self.weight = nn.Parameter(torch.randn(groups * self.group_out, self.group_in) * 0.02) - - def forward(self, x: Tensor) -> Tensor: - if self.normed: - x = F.rms_norm(x, (x.size(-1),)) - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_g / scale).round().clamp(-1, 1) - w_ternary = w + ((q * scale).reshape(w.shape) - w).detach() - w_grouped = w_ternary.reshape(self.groups, self.group_out, self.group_in) - bsz = x.shape[:-1] - x_g = x.reshape(*bsz, self.groups, self.group_in) - out = torch.einsum('...gi,goi->...go', x_g, w_grouped) - return out.reshape(*bsz, self.groups * self.group_out) - -class TverskyProjection(nn.Module): - "Tversky similarity: S = θ·f(A∩B) - α·f(A\\B) - β·f(B\\A). Three modes." - def __init__(self, in_features: int, out_features: int, num_features: int = 16, - group_size: int = 64, use_shared_features: bool = False, - membership: str = "sigmoid"): - super().__init__() - self.group_size = group_size - self.num_features = num_features - self.membership_type = membership - self.no_features_mode = (num_features == 0) - - if not self.no_features_mode and not use_shared_features: - self.features = nn.Parameter(torch.empty(num_features, in_features).uniform_(-0.02, 0.02)) - else: - self.register_parameter('features', None) - - self.prototypes = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.02, 0.02)) - self.theta = nn.Parameter(torch.tensor(1.0)) - self.alpha = nn.Parameter(torch.tensor(0.5)) - self.beta = nn.Parameter(torch.tensor(0.5)) - - def _ternary_ste(self, w: Tensor) -> Tensor: - w_bf16 = w.bfloat16() - g = self.group_size - w_grouped = w_bf16.reshape(-1, g) - scale = w_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_grouped / scale).round().clamp(-1, 1) - w_ternary = w_bf16 + ((q * scale).reshape(w_bf16.shape) - w_bf16).detach() - return w_ternary.reshape(w.shape) - - def _membership(self, t: Tensor) -> Tensor: - if self.membership_type == "poly": - return torch.clamp(t * 5.0 / 4.0 + 0.5, 0.0, 1.0) - elif self.membership_type == "tanh": - return (torch.tanh(t * 5.0) + 1.0) * 0.5 - else: - return torch.sigmoid(t * 5.0) - - def forward(self, x: Tensor, shared_features: Tensor | None = None) -> Tensor: - proto = self._ternary_ste(self.prototypes) - - if self.no_features_mode: - # NoFeatures: prototypes are their own feature universe - x_f = x @ proto.t() # [B, S, out] - p_norm = F.normalize(proto, dim=-1) - p_f = p_norm @ p_norm.t() # [out, out] - else: - feat = (shared_features if shared_features is not None else self.features).float() - x_f = x @ feat.t() # [B, S, nf] - p_f = proto @ feat.t() # [out, nf] - - x_s = self._membership(x_f) - p_s = self._membership(p_f) - x_a = x_f * x_s - p_a = p_f * p_s - - t, a, b = self.theta.abs(), self.alpha.abs(), self.beta.abs() - return t * (x_a @ p_a.t()) - a * (x_a @ (1 - p_s).t()) - b * ((1 - x_s) @ p_a.t()) - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: - param.data = param.data.float() - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, no_cache: bool = False, - rope_type: str = "rope", yarn_max_len: int = 4096, train_seq_len: int = 1024): - super().__init__() - self.no_cache = no_cache - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - if rope_type == "yarn": - scale = train_seq_len / yarn_max_len - freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) - ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) - inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len, device, dtype): - if self.no_cache: - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - return freqs.cos()[None, :, None, :].to(dtype=dtype), freqs.sin()[None, :, None, :].to(dtype=dtype) - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size=64, attn_proj_type="standard", tversky_num_features=16, - tversky_feature_pools=0, no_cache=False, rope_type="rope", - yarn_max_len=4096, train_seq_len=1024, tversky_membership="sigmoid", - diff_attn=False): - super().__init__() - self.num_heads, self.num_kv_heads = num_heads, num_kv_heads - self.head_dim = dim // num_heads - self.diff_attn = diff_attn - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.c_qkv = TernaryLinear(dim, self.q_size + 2 * self.kv_size, bias=False, group_size=group_size) - self.proj = NormedTernaryLinear(dim, dim, bias=False, group_size=group_size) if attn_proj_type != "tversky" else None - if self.proj is not None: - self.proj._zero_init = True - self.tversky_proj = TverskyProjection( - dim, dim, num_features=tversky_num_features, group_size=group_size, - use_shared_features=(tversky_feature_pools > 0), - membership=tversky_membership, - ) if attn_proj_type == "tversky" else None - self.shared_features = None - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - if diff_attn: - self.diff_lambda = nn.Parameter(torch.full((num_heads,), 0.5, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, no_cache=no_cache, - rope_type=rope_type, yarn_max_len=yarn_max_len, - train_seq_len=train_seq_len) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qkv_out = self.c_qkv(x) - q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - if self.diff_attn: - half = self.head_dim // 2 - q1, q2 = q[..., :half], q[..., half:] - k1, k2 = k[..., :half], k[..., half:] - v1, v2 = v[..., :half], v[..., half:] - y1 = flash_attn_func(q1.contiguous(), k1.contiguous(), v1.contiguous(), causal=True) - y2 = flash_attn_func(q2.contiguous(), k2.contiguous(), v2.contiguous(), causal=True) - lam = self.diff_lambda.to(dtype=y1.dtype)[None, None, :, None] - y = torch.cat([y1 - lam * y2, y1 + lam * y2], dim=-1) - else: - y = flash_attn_func( - q.contiguous(), - k.contiguous(), - v.contiguous(), - causal=True - ) - y = y.reshape(bsz, seqlen, dim) - return self.tversky_proj(y, self.shared_features) if self.tversky_proj is not None else self.proj(y) - -class MLP(nn.Module): - def __init__(self, dim, mlp_mult, group_size=64, activation="swiglu", mlp_groups=0): - super().__init__() - hidden = mlp_mult * dim - self.activation = activation - if mlp_groups > 0: - if activation == "swiglu": - self.gate_up = GroupedTernaryLinear(dim, hidden * 2, groups=mlp_groups, group_size=group_size) - else: - self.fc = GroupedTernaryLinear(dim, hidden, groups=mlp_groups, group_size=group_size) - self.proj = GroupedTernaryLinear(hidden, dim, groups=mlp_groups, group_size=group_size, normed=True) - else: - if activation == "swiglu": - self.gate_up = TernaryLinear(dim, hidden * 2, bias=False, group_size=group_size) - else: - self.fc = TernaryLinear(dim, hidden, bias=False, group_size=group_size) - self.proj = NormedTernaryLinear(hidden, dim, bias=False, group_size=group_size) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - if self.activation == "swiglu": - gu = self.gate_up(x) - gate, up = gu.chunk(2, dim=-1) - return self.proj(F.silu(gate) * up) - elif self.activation == "relu": - return self.proj(torch.relu(self.fc(x))) - elif self.activation == "leaky_relu": - return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.01)) - else: # relu2 - return self.proj(torch.relu(self.fc(x)).square()) - -class SmearModule(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - cumsum = x.cumsum(dim=1) - counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) - smeared = cumsum / counts - gate = torch.tanh(self.gate.to(dtype=x.dtype)) - return x + gate * (smeared - x) - - -class CausalConvRefiner(nn.Module): - "Causal Conv1d that refines hidden states using local n-gram context." - def __init__(self, dim: int, kernel_size: int = 3): - super().__init__() - self.kernel_size = kernel_size - self.conv = nn.Conv1d(dim, dim, kernel_size, padding=0, bias=False) - self.gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - h = x.permute(0, 2, 1) # [B, D, S] - h = F.pad(h, (self.kernel_size - 1, 0)) # causal pad - h = self.conv(h) - h = h.permute(0, 2, 1) # [B, S, D] - return x + torch.tanh(self.gate.to(dtype=x.dtype)) * F.rms_norm(h, (h.size(-1),)) - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - rope_base: float, qk_gain_init: float, group_size: int=64, - activation: str="swiglu", attn_proj_type: str="standard", - tversky_num_features: int=16, tversky_feature_pools: int=0, no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn: bool=False, mlp_groups: int=0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size, attn_proj_type, tversky_num_features, - tversky_feature_pools, no_cache, rope_type, yarn_max_len, - train_seq_len, tversky_membership, diff_attn) - self.mlp = MLP(dim, mlp_mult, group_size, activation, mlp_groups) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.smear = SmearModule(dim) if smear else None - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0] * x + mix[1] * x0 - n = self.attn_norm(x) - x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) - x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) - if self.smear is not None: - x = self.smear(x) - return x - -class GPT(nn.Module): - def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, - tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, - group_size: int = 64, activation: str = "swiglu", mtp_heads_count: int = 0, - embed_dim: int = 0, attn_proj_type: str = "standard", logit_head_type: str = "standard", - tversky_num_features: int = 16, tversky_feature_pools: int = 0, - training_depth_recurrence: int=1, fp_storage=False, bigram_hash: bool=False, - softcap_type: str="poly", no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn=False, mlp_groups=0, refiner=False, refiner_kernel=3): - super().__init__() - self.training_depth_recurrence = training_depth_recurrence - self.fp_storage = fp_storage - self.tie_embeddings = tie_embeddings - self.logit_softcap = logit_softcap - self.softcap_type = softcap_type - self.embed_dim = embed_dim if embed_dim > 0 else model_dim - self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) - self.bigram_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) if bigram_hash else None - if self.bigram_emb is not None: - nn.init.zeros_(self.bigram_emb.weight) - self.lm_head_correction = nn.Parameter( - torch.zeros(vocab_size, self.embed_dim)) if tie_embeddings == 2 else None - self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, fp_storage=fp_storage) if self.embed_dim != model_dim else None - self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, fp_storage=fp_storage) if ( - self.embed_dim != model_dim and logit_head_type != "tversky") else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - - # Shared Tversky feature pools (if enabled and num_features > 0) - if attn_proj_type == "tversky" and tversky_feature_pools > 0 and tversky_num_features > 0: - self.tversky_feature_pools_list = nn.ParameterList([ - nn.Parameter(torch.empty(tversky_num_features, model_dim).uniform_(-0.02, 0.02)) - for _ in range(tversky_feature_pools) - ]) - else: - self.tversky_feature_pools_list = None - - self.blocks = nn.ModuleList([ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - group_size, activation, attn_proj_type, tversky_num_features, tversky_feature_pools, - no_cache, smear, rope_type, yarn_max_len, train_seq_len, tversky_membership, - diff_attn, mlp_groups) - for _ in range(num_layers) - ]) - - # Inject shared feature pool references into attention layers - if self.tversky_feature_pools_list is not None: - for i, block in enumerate(self.blocks): - pool_idx = (i * tversky_feature_pools) // num_layers - block.attn.shared_features = self.tversky_feature_pools_list[pool_idx] - - self.final_norm = RMSNorm() - self.refiner = CausalConvRefiner(model_dim, kernel_size=refiner_kernel) if refiner else None - self.mtp_heads = nn.ModuleList([ - nn.Linear(model_dim, vocab_size, bias=False) for _ in range(mtp_heads_count) - ]) - for h in self.mtp_heads: - nn.init.zeros_(h.weight) - self.logit_head_type = logit_head_type - if logit_head_type == "tversky" and tversky_num_features == 0 and vocab_size > 1024: - raise ValueError( - f"Tversky logit head with no-features mode creates O(V^2) = {vocab_size}x{vocab_size} " - f"matrix per forward pass. Use tversky_num_features > 0 or a smaller vocab." - ) - self.tversky_head = TverskyProjection( - model_dim, vocab_size, num_features=tversky_num_features, - membership=tversky_membership, - ) if logit_head_type == "tversky" else None - self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) - self.lm_head._zero_init = True - if self.lm_head is not None and (tie_embeddings or logit_head_type == "tversky"): - self.lm_head.weight.requires_grad_(False) - - self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) - self._init_weights(tied_embed_init_std) - - def _init_weights(self, tied_embed_init_std: float) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) - for module in self.modules(): - if isinstance(module, TernaryLinear) and not getattr(module, "_zero_init", False): - nn.init.normal_(module.weight, mean=0.0, std=0.02) - elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def _compute_logits(self, x: Tensor) -> Tensor: - if self.tversky_head is not None: - logits_raw = self.tversky_head(x) - elif self.tie_embeddings: - if self.embed_proj_rev is not None: - proj = self.embed_proj_rev(x) - else: - proj = x - weight = self.tok_emb.weight - if self.lm_head_correction is not None: - weight = weight + self.lm_head_correction - logits_raw = F.linear(proj, weight.to(x.dtype)) - else: - logits_raw = self.lm_head(x) - return logits_raw + self.vocab_bias.to(x.dtype) - - def _softcap(self, logits: Tensor) -> Tensor: - s = self.logit_softcap - if self.softcap_type == "tanh": - return s * torch.tanh(logits / s) - x_sc = torch.clamp(logits / s, -2.0, 2.0) - x2 = x_sc * x_sc - return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) - - def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", temperature: float = 1.0) -> Tensor: - x = self.tok_emb(input_ids).float() - if self.bigram_emb is not None: - prev = F.pad(input_ids[:, :-1], (1, 0), value=0) - x = x + self.bigram_emb(prev).float() - if self.embed_proj is not None: - x = self.embed_proj(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - - # U-Net style encoder/decoder with skip connections - skips = [] - for i in range(self.num_encoder_layers): - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[bi](x, x0) - - x_normed = self.final_norm(x) - if self.refiner is not None: - x_normed = self.refiner(x_normed) - - # Standard training/eval path - x_flat = x_normed.reshape(-1, x_normed.size(-1)) - targets = target_ids.reshape(-1) - logits = self._softcap(self._compute_logits(x_flat)) - - if reduction == "none": - return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) - - # Fused CE + Z-loss: single logsumexp computation - logits_f = logits.float() - lse = torch.logsumexp(logits_f, dim=-1) - target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) - main_loss = (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() - - # Multi-token prediction auxiliary loss (training only) - if self.training and len(self.mtp_heads) > 0: - mtp_loss = torch.zeros((), device=main_loss.device) - for k, head in enumerate(self.mtp_heads): - shift = k + 2 - if target_ids.shape[1] > shift: - mtp_tgt = target_ids[:, shift:].reshape(-1) - mtp_in = x_normed[:, :target_ids.shape[1] - shift, :].reshape(-1, x_normed.shape[-1]) - mtp_loss = mtp_loss + F.cross_entropy(head(mtp_in).float(), mtp_tgt, reduction="mean") - main_loss = main_loss + 0.1 * mtp_loss / len(self.mtp_heads) - return main_loss - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- - -def build_luts(sp, vocab_size: int, device: torch.device): - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): - files = sorted(glob.glob(pattern)) - assert files, f"No files: {pattern}" - tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() - if max_tok > 0: tok = tok[:max_tok + 1] - u = ((tok.numel() - 1) // seq_len) * seq_len - return tok[:u + 1] - -def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature: float = 1.0): - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_start in range(seq_start, seq_end, local_batch_seqs): - batch_end = min(batch_start + local_batch_seqs, seq_end) - raw_start = batch_start * args.train_seq_len - raw_end = batch_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) - x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch_loss = model(x, y, temperature=temperature).detach() - n = float(y.numel()) - loss_sum += batch_loss.to(torch.float64) * n - token_count += n - prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) - tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) - tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride: int = 64, temperature: float = 1.0): - seq_len = args.train_seq_len - batch_size = args.sliding_batch_size - total_tokens = val_tokens.numel() - 1 - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - all_starts = list(range(0, total_tokens - seq_len, stride)) - my_starts = all_starts[rank::world_size] - - model.eval() - with torch.inference_mode(): - for i in range(0, len(my_starts), batch_size): - batch_starts = my_starts[i:i + batch_size] - - starts_t = torch.tensor(batch_starts, dtype=torch.int64) - offsets = torch.arange(seq_len + 1, dtype=torch.int64) - indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) - - local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) - x = local_batch[:, :-1] - y = local_batch[:, 1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() - - for b, start in enumerate(batch_starts): - score_from = 0 if start == 0 else seq_len - stride - scored = per_token_loss[b, score_from:] - sx, sy = x[b, score_from:], y[b, score_from:] - - loss_sum += scored.to(torch.float64).sum() - token_count += scored.numel() - - tok_bytes = base_bytes_lut[sy].to(torch.int16) - tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -# --------------------------------------------------------------------------- -# Temperature scaling -# --------------------------------------------------------------------------- -def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut): - best_t, best_loss = 1.0, float("inf") - for t in [0.90, 0.95, 1.00, 1.05, 1.10]: - loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, temperature=t) - if loss < best_loss: - best_loss = loss - best_t = t - return best_t - -# --------------------------------------------------------------------------- -# Training -# --------------------------------------------------------------------------- -def main() -> None: - args = Hyperparameters() - code = Path(__file__).read_text(encoding="utf-8") - - if args.matrix_optimizer != "adamw": - global ns_orth - ns_orth = torch.compile(ns_orth) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - grad_accum_steps = max(1, 8 // world_size) - grad_scale = 1.0 / grad_accum_steps - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - os.makedirs("logs/cuda/", exist_ok=True) - logfile = f"logs/cuda/{args.run_id}.txt" if master_process else None - if master_process: - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - - log0(f"Python {sys.version}", console=False) - log0(f"PyTorch {torch.__version__}", console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - val_tokens = ld_val(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts( - sp, args.vocab_size, device) - - # --- Model --- - base_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - group_size=args.bitnet_group_size, activation=args.activation_type, mtp_heads_count=args.mtp_heads_count, - embed_dim=args.embed_dim, attn_proj_type=args.attn_proj_type, logit_head_type=args.logit_head_type, - tversky_num_features=args.tversky_num_features, tversky_feature_pools=args.tversky_feature_pools, - training_depth_recurrence=args.training_depth_recurrence, fp_storage=args.fp_storage, - bigram_hash=args.bigram_hash, softcap_type=args.softcap_type, no_cache=(args.compile_mode == "reduce-overhead"), - smear=args.smear, rope_type=args.rope_type, yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, - tversky_membership=args.tversky_membership, diff_attn=args.diff_attn, - refiner=args.refiner, refiner_kernel=args.refiner_kernel, mlp_groups=args.mlp_groups, - ).to(device).bfloat16() - - for module in base_model.modules(): - if isinstance(module, nn.Linear): - module.float() - restore_low_dim_params_to_fp32(base_model) - if base_model.lm_head is not None and (args.tie_embeddings or args.logit_head_type == "tversky"): - base_model.lm_head.weight.requires_grad_(False) - - torch._dynamo.config.optimize_ddp = False - - compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else None) - use_find_unused = args.untie_at_fraction > 0 or args.mtp_heads_count > 0 or not args.tie_embeddings - model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, - find_unused_parameters=use_find_unused, - static_graph=not use_find_unused, - gradient_as_bucket_view=True) if distributed else compiled_model - - # --- Optimizers --- - _excl = {"tok_emb.weight", "lm_head.weight", "lm_head_correction"} - all_other_params = [(n, p) for n, p in base_model.named_parameters() - if not any(eh in n for eh in _excl)] - matrix_params = [p for n, p in all_other_params - if p.ndim == 2 and not any(pat in n for pat in CTP)] - scalar_params = [p for n, p in all_other_params - if p.ndim < 2 or any(pat in n for pat in CTP)] - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - opt_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - if args.matrix_optimizer == "adamw": - opt_muon = torch.optim.AdamW( - [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) - else: - opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, wd=args.muon_wd) - for g in opt_muon.param_groups: - g["base_lr"] = args.matrix_lr - opt_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - opt_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - - optimizers = [opt for opt in [opt_tok, opt_muon, opt_scalar, opt_head] if opt is not None] - - if base_model.lm_head_correction is not None: - opt_corr = torch.optim.Adam( - [{"params": [base_model.lm_head_correction], - "lr": args.corr_weight_lr, "base_lr": args.corr_weight_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - optimizers.append(opt_corr) - - # --- Log all hyperparameters --- - log0("--- Hyperparameters ---", console=False) - log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) if not a.startswith("_") and a not in ("train_files","val_files") and not callable(getattr(args,a))), console=False) - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"params:{n_params} L:{args.num_layers} d:{args.model_dim} h:{args.num_heads} kv:{args.num_kv_heads} ws:{world_size} ga:{grad_accum_steps} s:{args.seed}") - - # --- Data loader & helpers --- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all(): - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float): - if args.warmdown_fraction <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) - return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 - warmdown_ms = max_wallclock_ms * args.warmdown_fraction - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - _seq_switched = False - _batch_switched = False - active_seq_len = args.seq_len_start if args.seq_len_start > 0 else args.train_seq_len - active_batch_tokens = args.batch_tokens_start if args.batch_tokens_start > 0 else args.train_batch_tokens - - # --- Compiler warmup --- - if args.warmup_steps > 0: - _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} - _os = [copy.deepcopy(o.state_dict()) for o in optimizers] - model.train() - for ws in range(args.warmup_steps): - zero_grad_all() - for mi in range(grad_accum_steps): - if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) - (loss * grad_scale).backward() - for o in optimizers: o.step() - zero_grad_all() - log0(f"warmup:{ws+1}/{args.warmup_steps}") - base_model.load_state_dict(_ms, strict=True) - for o, s in zip(optimizers, _os): o.load_state_dict(s) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # --- Main training loop --- - training_time_ms = 0.0 - stop_after_step: int | None = None - _untied = False - train_loss = torch.zeros((), device=device) - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - tstats = tern_stats(base_model, group_size=args.bitnet_group_size) - log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms zero_frac:{tstats['zero_frac']:.3f}") - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - - # Sequence length schedule - if args.seq_len_start > 0 and not _seq_switched: - if max_wallclock_ms is not None: - should_switch_seq = elapsed_ms >= args.seq_schedule_fraction * max_wallclock_ms - else: - should_switch_seq = step >= int(args.iterations * args.seq_schedule_fraction) - if should_switch_seq: - active_seq_len = args.train_seq_len - _seq_switched = True - torch._dynamo.reset() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - log0(f"step:{step} seq_len_switch:{args.seq_len_start}->{active_seq_len}") - - # Batch size schedule - if args.batch_tokens_start > 0 and not _batch_switched: - if max_wallclock_ms is not None: - should_switch_batch = elapsed_ms >= args.batch_schedule_fraction * max_wallclock_ms - else: - should_switch_batch = step >= int(args.iterations * args.batch_schedule_fraction) - if should_switch_batch: - active_batch_tokens = args.train_batch_tokens - _batch_switched = True - log0(f"step:{step} batch_switch:{args.batch_tokens_start}->{active_batch_tokens}") - - zero_grad_all() - train_loss.zero_() - - for micro in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = model(x, y) - train_loss.add_(loss.detach()) - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - # Untie lm_head at configured fraction of training - if args.untie_at_fraction > 0: - if max_wallclock_ms is not None: - should_untie = not _untied and elapsed_ms >= args.untie_at_fraction * max_wallclock_ms - else: - should_untie = not _untied and step >= int(args.iterations * args.untie_at_fraction) - if should_untie and base_model.tie_embeddings: - with torch.no_grad(): - base_weight = base_model.tok_emb.weight.float() - if base_model.lm_head_correction is not None: - base_weight = base_weight + base_model.lm_head_correction.float() - if base_model.embed_proj_rev is not None: - full_weight = base_weight @ base_model.embed_proj_rev.weight.float() - else: - full_weight = base_weight - base_model.lm_head.weight.copy_(full_weight) - base_model.tie_embeddings = False - base_model.lm_head.weight.requires_grad_(True) - for g in opt_head.param_groups: - g["lr"] = g["base_lr"] = args.head_lr - _untied = True - torch._dynamo.reset() - log0(f"step:{step} untied lm_head (head_lr={args.head_lr})") - - # Muon momentum warmup - if args.matrix_optimizer != "adam": - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - for g in opt_muon.param_groups: - g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - - # LR scheduling - for opt in optimizers: - for g in opt.param_groups: - g["lr"] = g["base_lr"] * scale - opt.step() - zero_grad_all() - step += 1 - approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if args.train_log_every > 0 and step % args.train_log_every == 0: - log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} t:{approx_ms:.0f}ms avg:{approx_ms/step:.1f}ms") - if args.churn_log_every > 0 and step % args.churn_log_every == 0: - log0(f"step:{step} churn:{churn_fn(base_model, args.bitnet_group_size):.4f} zero:{tern_stats(base_model, args.bitnet_group_size)['zero_frac']:.3f}") - - # Wallclock cap sync - if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: - reached_cap = approx_ms >= max_wallclock_ms - if distributed: - cap_t = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) - reached_cap = bool(cap_t.item()) - if reached_cap: - stop_after_step = step - - # --- Serialization --- - if master_process: - sd = base_model.state_dict() - if base_model.tie_embeddings or args.logit_head_type == "tversky": - sd.pop("lm_head.weight", None) - - # Compute ternary overrides for no-features Tversky prototypes - ternary_overrides = set() - for n, m in base_model.named_modules(): - if isinstance(m, TverskyProjection) and m.no_features_mode: - ternary_overrides.add(n + ".prototypes") - ternary_overrides = ternary_overrides or None - - # Two methods: Standard Base-3 vs Bitmask Mapping - methods = {} - for method in ("standard", "bitmask"): - q_obj, stats = q_sd(sd, group_size=args.bitnet_group_size, fp_storage=args.fp_storage, ternary_method=method, ternary_override_names=ternary_overrides) - buf = io.BytesIO() - torch.save(q_obj, buf) - methods[method] = {"blob": lzma.compress(buf.getvalue(), preset=9), "stats": stats} - best = min(methods, key=lambda m: len(methods[m]["blob"])) - final_blob, q_stats = methods[best]["blob"], methods[best]["stats"] - with open("final_model.ternary.ptz", "wb") as f: - f.write(final_blob) - - artifact_bytes = len(final_blob) - code_bytes = len(code.encode("utf-8")) - - total = artifact_bytes + code_bytes - log0(f"artifact:{artifact_bytes/1e6:.2f}MB ternary:{q_stats['ternary_params']}({q_stats['ternary_bytes']}B) fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") - log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) {'FITS' if total <= 16000000 else 'OVER'}") - - if args.eval_depth_recurrence > 0: - base_model.training_depth_recurrence = args.eval_depth_recurrence - log0(f"eval_depth_recurrence:{args.eval_depth_recurrence}") - - # --- All ranks load roundtrip weights and evaluate --- - if distributed: - dist.barrier() - - with open("final_model.ternary.ptz", "rb") as f: - loaded = torch.load(io.BytesIO(lzma.decompress(f.read())), map_location="cpu", weights_only=False) - base_model.load_state_dict(deq_sd(loaded), strict=False) - torch._dynamo.reset() - - q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - log0(f"final_ternary_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") - - opt_temp = 1.0 - if args.temp_scaling: - torch.cuda.synchronize() - t_temp = time.perf_counter() - calibration_tokens = train_loader.stream.take(65536).to(device) - opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut) - torch.cuda.synchronize() - temp_time_ms = 1000.0 * (time.perf_counter() - t_temp) - log0(f"temp_scaling optimal_T:{opt_temp:.2f} eval_time:{temp_time_ms:.0f}ms") - - if args.sliding_eval: - torch.cuda.synchronize() - t_sliding = time.perf_counter() - sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, stride=args.sliding_eval_stride, - temperature=opt_temp) - torch.cuda.synchronize() - sliding_time_ms = 1000.0 * (time.perf_counter() - t_sliding) - log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " - f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) eval_time:{sliding_time_ms:.0f}ms") - - if distributed: - dist.destroy_process_group() - -if __name__ == "__main__": - main() -==================================================================================================== -Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] -PyTorch 2.10.0+cu128 ---- Hyperparameters --- -activation_type=relu2 adam_eps=1e-08 adam_lr=0.05 adam_wd=0.05 attn_proj_type=standard batch_schedule_fraction=0.33 batch_tokens_start=0 beta1=0.9 beta2=0.95 bigram_hash=False bitnet_group_size=128 churn_log_every=0 compile_mode=default corr_weight_lr=0.02 data_path=./data/datasets/fineweb10B_sp8192 diff_attn=False embed_dim=254 embed_lr=0.6 eval_depth_recurrence=0 fp_storage=True grad_clip_norm=0.0 head_lr=0.02 iterations=10000 logit_head_type=standard logit_softcap=10.0 matrix_lr=0.04 matrix_optimizer=muon max_wallclock_seconds=599.0 mlp_groups=0 mlp_mult=4 model_dim=768 mtp_heads_count=0 muon_backend_steps=3 muon_momentum=0.95 muon_momentum_warmup_start=0.85 muon_momentum_warmup_steps=500 muon_wd=0.0 num_heads=8 num_kv_heads=4 num_layers=10 qk_gain_init=2.25 refiner=False refiner_kernel=3 rope_base=5000.0 rope_type=yarn run_id=pushing_run_ternary_1 scalar_lr=0.02 seed=1337 seq_len_start=0 seq_schedule_fraction=0.0 sliding_batch_size=256 sliding_eval=True sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.02 tokenizer_path=./data/tokenizers/fineweb_8192_bpe.model train_batch_tokens=524288 train_log_every=1000 train_seq_len=1024 training_depth_recurrence=0 tversky_feature_pools=1 tversky_membership=sigmoid tversky_num_features=128 untie_at_fraction=0.0 val_batch_size=524288 val_loss_every=0 vocab_size=8192 warmdown_fraction=0.2 warmup_steps=5 yarn_max_len=2048 -params:73685840 L:10 d:768 h:8 kv:4 ws:8 ga:1 s:1337 -warmup:1/5 -warmup:2/5 -warmup:3/5 -warmup:4/5 -warmup:5/5 -step:1000/10000 loss:3.3139 t:91682ms avg:91.7ms -step:2000/10000 loss:3.2956 t:183563ms avg:91.8ms -step:3000/10000 loss:3.1547 t:275555ms avg:91.9ms -step:4000/10000 loss:3.3160 t:367464ms avg:91.9ms -step:5000/10000 loss:3.1719 t:459393ms avg:91.9ms -step:6000/10000 loss:3.0246 t:551299ms avg:91.9ms -step:6520/10000 val_loss:3.0541 val_bpb:1.1825 train_time:599128ms zero_frac:0.335 -stopping_early: wallclock_cap train_time:599128ms step:6520/10000 -artifact:15.92MB ternary:64880640(14245920B) fp:2513744(2537376B) code:70853 -budget:15995705/16000000 (16.00/16.00MB) FITS -final_ternary_roundtrip val_loss:3.0577 val_bpb:1.1839 -temp_scaling optimal_T:0.90 eval_time:151ms -final_sliding val_loss:2.9877 val_bpb:1.1568 (stride=16, T=0.90) eval_time:428712ms diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/ternary_log_42.txt b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/ternary_log_42.txt deleted file mode 100644 index e224ed1f9b..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/ternary_log_42.txt +++ /dev/null @@ -1,1460 +0,0 @@ -"Ternary training script for OpenAI's Parameter Golf Challenge. Ciprian-Florin Ifrim - 24 March 2026" - -import copy -import glob -import io -import math -import os -import random -import sys -import time -import lzma -from pathlib import Path -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func - -# --------------------------------------------------------------------------- -# Hyperparameters (all configurable via environment variables) -# --------------------------------------------------------------------------- -def _e(k, d, t=str): - v = os.environ.get(k, str(d)) - if t == bool: return bool(int(v)) - return t(v) - -class Hyperparameters: - data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", f"run_{int(time.time())}") - seed = _e("SEED", 1337, int) - compile_mode = _e("COMPILE_MODE", "default") - val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) - val_loss_every = _e("VAL_LOSS_EVERY", 500, int) - train_log_every = _e("TRAIN_LOG_EVERY", 10, int) - iterations = _e("ITERATIONS", 2000, int) - warmdown_fraction = _e("WARMDOWN_FRACTION", 0.2, float) - warmup_steps = _e("WARMUP_STEPS", 20, int) - train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) - train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) - max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) - vocab_size = _e("VOCAB_SIZE", 1024, int) - num_layers = _e("NUM_LAYERS", 16, int) - num_kv_heads = _e("NUM_KV_HEADS", 4, int) - model_dim = _e("MODEL_DIM", 512, int) - num_heads = _e("NUM_HEADS", 8, int) - mlp_mult = _e("MLP_MULT", 2, int) - tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) - rope_base = _e("ROPE_BASE", 10000.0, float) - rope_type = _e("ROPE_TYPE", "rope") - yarn_max_len = _e("YARN_MAX_LEN", 4096, int) - logit_softcap = _e("LOGIT_SOFTCAP", 30.0, float) - softcap_type = _e("SOFTCAP_TYPE", "poly") - tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) - qk_gain_init = _e("QK_GAIN_INIT", 1.5, float) - activation_type = _e("ACTIVATION", "swiglu") - embed_dim = _e("EMBED_DIM", 0, int) - bigram_hash = _e("BIGRAM_HASH", 0, bool) - mtp_heads_count = _e("MTP_HEADS", 0, int) - training_depth_recurrence = _e("TRAINING_DEPTH_RECURRENCE", 1, int) - eval_depth_recurrence = _e("EVAL_DEPTH_RECURRENCE", 1, int) - attn_proj_type = _e("ATTN_PROJ_TYPE", "standard") - logit_head_type = _e("LOGIT_HEAD_TYPE", "standard") - tversky_num_features = _e("TVERSKY_NUM_FEATURES", 16, int) - tversky_feature_pools = _e("TVERSKY_FEATURE_POOLS", 0, int) - tversky_membership = _e("TVERSKY_MEMBERSHIP", "sigmoid") - diff_attn = _e("DIFF_ATTN", 0, bool) - refiner = _e("REFINER", 0, bool) - refiner_kernel = _e("REFINER_KERNEL", 3, int) - mlp_groups = _e("MLP_GROUPS", 0, int) - embed_lr = _e("EMBED_LR", 0.6, float) - head_lr = _e("HEAD_LR", 0.008, float) - adam_lr = _e("ADAM_LR", 1e-3, float) - adam_wd = _e("ADAM_WD", 0.05, float) - untie_at_fraction = _e("UNTIE_AT_FRACTION", 0.0, float) - tied_embed_lr = _e("TIED_EMBED_LR", 0.05, float) - corr_weight_lr = _e("CORR_WEIGHT_LR", 0.05, float) - smear = _e("SMEAR", 0, bool) - seq_len_start = _e("SEQ_LEN_START", 0, int) - seq_schedule_fraction = _e("SEQ_SCHEDULE_FRACTION", 0.33, float) - batch_tokens_start = _e("BATCH_TOKENS_START", 0, int) - batch_schedule_fraction = _e("BATCH_SCHEDULE_FRACTION", 0.33, float) - churn_log_every = _e("CHURN_LOG_EVERY", 500, int) - matrix_lr = _e("MATRIX_LR", 0.04, float) - scalar_lr = _e("SCALAR_LR", 0.04, float) - muon_momentum = _e("MUON_MOMENTUM", 0.95, float) - muon_backend_steps = _e("MUON_BACKEND_STEPS", 5, int) - muon_wd = _e("MUON_WD", 0.0, float) - matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") - muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) - muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) - beta1 = _e("BETA1", 0.9, float) - beta2 = _e("BETA2", 0.95, float) - adam_eps = _e("ADAM_EPS", 1e-8, float) - grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) - bitnet_group_size = _e("BITNET_GROUP_SIZE", 64, int) - sliding_eval = _e("SLIDING_EVAL", 0, bool) - sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 64, int) - sliding_batch_size = _e("SLIDING_BATCH_SIZE", 64, int) - temp_scaling = _e("TEMP_SCALING", 0, bool) - _fp_raw = os.environ.get("FP_STORAGE", "0") - fp_storage = True if _fp_raw == "FP8" else ("fp4" if _fp_raw == "FP4" else False) - -CTP = ("attn_scale","attn_scales","mlp_scale","mlp_scales","resid_mix","resid_mixes","q_gain","diff_lambda","skip_weight","skip_weights","vocab_bias","refiner.gate") - -# --------------------------------------------------------------------------- -# Ternary packing — base-3 encoding (5 trits/byte) -# --------------------------------------------------------------------------- -def pack_ternary(q: Tensor): - f = (q.reshape(-1).to(torch.int8) + 1).numpy() - n = len(f) - p = (5 - n % 5) % 5 - if p: f = np.concatenate([f, np.zeros(p, dtype=np.int8)]) - g = f.reshape(-1, 5).astype(np.uint8) - return (g[:,0] + g[:,1]*3 + g[:,2]*9 + g[:,3]*27 + g[:,4]*81).tobytes(), n - -def unpack_ternary(data: bytes, n: int) -> Tensor: - v = np.frombuffer(data, dtype=np.uint8).astype(np.int16) - t = np.zeros((len(v), 5), dtype=np.int8) - for i in range(5): t[:,i] = v % 3; v //= 3 - return torch.from_numpy(t.reshape(-1)[:n].astype(np.int8) - 1) - -def pack_ternary_bitmask(q: Tensor): - f = q.reshape(-1).to(torch.int8).numpy(); n = len(f) - nz = (f != 0) - return np.packbits(nz).tobytes() + np.packbits(f[nz] > 0).tobytes(), n - -def unpack_ternary_bitmask(data: bytes, n: int) -> Tensor: - ms = (n + 7) // 8 - nz = np.unpackbits(np.frombuffer(data[:ms], dtype=np.uint8))[:n].astype(bool) - s = np.unpackbits(np.frombuffer(data[ms:], dtype=np.uint8))[:int(nz.sum())].astype(bool) - w = np.zeros(n, dtype=np.int8); w[nz] = np.where(s, 1, -1) - return torch.from_numpy(w) - -# --------------------------------------------------------------------------- -# FP4 quantization (per-row absmax, 2 values packed per byte) -# --------------------------------------------------------------------------- -def quantize_to_int4(t: Tensor) -> tuple[Tensor, Tensor, list]: - t32 = t.float() - orig_shape = t32.shape - if t32.ndim < 2: - t32 = t32.unsqueeze(0) - absmax = t32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(t32 / scale), -7, 7).to(torch.int8) - flat = q.reshape(-1) - if flat.numel() % 2 != 0: - flat = F.pad(flat, (0, 1)) - low = (flat[0::2] + 8).to(torch.uint8) - high = (flat[1::2] + 8).to(torch.uint8) - return low | (high << 4), scale.half().squeeze(-1), list(orig_shape) - -def dequantize_from_int4(packed: Tensor, scale: Tensor, shape: list) -> Tensor: - low = (packed & 0x0F).to(torch.int8) - 8 - high = ((packed >> 4) & 0x0F).to(torch.int8) - 8 - flat = torch.zeros(packed.numel() * 2, dtype=torch.int8) - flat[0::2] = low - flat[1::2] = high - numel = 1 - for s in shape: - numel *= s - flat = flat[:numel].float() - if len(shape) <= 1: - return (flat * scale.float().squeeze()).reshape(shape) - return (flat.reshape(-1, shape[-1]) * scale.float().unsqueeze(-1)).reshape(shape) - -# --------------------------------------------------------------------------- -# State dict serialization (ternary + fp16/fp8/fp4) -# --------------------------------------------------------------------------- -def q_sd(state_dict: dict, group_size: int = 64, fp_storage=False, ternary_method="standard", ternary_override_names: set | None = None) -> tuple[dict, dict]: - "Ternary for large 2D weight matrices, fp16/fp8/fp4 for everything else." - quantized = {} - stats = {"ternary_params": 0, "ternary_bytes": 0, "fp_params": 0, "fp_bytes": 0} - for name, tensor in state_dict.items(): - if "mtp_heads" in name: - continue - t = tensor.detach().cpu().float().contiguous() - t_orig_shape = list(t.shape) - if t.ndim == 3: - t = t.reshape(t.shape[0], -1) - is_ternary_candidate = ( - t.ndim == 2 and t.numel() > 65_536 - and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name and "bigram_emb" not in name and "lm_head_correction" not in name and "lm_head_U" not in name and "lm_head_V" not in name - and "prototypes" not in name and "tversky" not in name - ) or (ternary_override_names is not None and name in ternary_override_names) - if is_ternary_candidate: - pad = (group_size - t.shape[1] % group_size) % group_size - t_padded = F.pad(t, (0, pad)) if pad > 0 else t - t_grouped = t_padded.reshape(-1, group_size) - scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (t_grouped / scale).round().clamp(-1, 1).to(torch.int8) - - if ternary_method == "standard": - packed_bytes, n_trits = pack_ternary(q) - entry_type = "ternary" - else: - packed_bytes, n_trits = pack_ternary_bitmask(q) - entry_type = "ternary_bitmask" - - quantized[name] = { - "type": entry_type, "packed": packed_bytes, - "scale": scale.half().squeeze(-1), - "shape": list(t.shape), "padded_cols": t_padded.shape[1], - "group_size": group_size, "n_trits": n_trits, - "orig_shape": t_orig_shape, - } - stats["ternary_params"] += t.numel() - stats["ternary_bytes"] += len(packed_bytes) + scale.numel() * 2 - elif fp_storage == "fp4" and t.ndim == 2: - packed, scale, orig_shape = quantize_to_int4(t) - quantized[name] = {"type": "fp4", "packed": packed, "scale": scale, "shape": orig_shape} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += packed.numel() + scale.numel() * 2 - elif fp_storage and t.ndim == 2: - quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() - else: - quantized[name] = {"type": "fp16", "data": t.half()} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() * 2 - return quantized, stats - -def deq_sd(quantized: dict, target_dtype=torch.bfloat16): - "Reconstruct full-precision state dict from quantized representation." - out = {} - for name, entry in quantized.items(): - if entry["type"] in ("ternary", "ternary_bitmask"): - if entry["type"] == "ternary": - q = unpack_ternary(entry["packed"], entry["n_trits"]) - else: - q = unpack_ternary_bitmask(entry["packed"], entry["n_trits"]) - - q = q.float().reshape(-1, entry["group_size"]) - scale = entry["scale"].float().unsqueeze(-1) - q_absmean = q.abs().mean(-1, keepdim=True).clamp(min=1e-8) - t = (q * (scale / q_absmean)).reshape(-1, entry["padded_cols"]) - shape = entry["shape"] - result = t[:shape[0], :shape[1]].to(target_dtype) - orig = entry.get("orig_shape") - out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() - elif entry["type"] == "fp8": - out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() - elif entry["type"] == "fp4": - out[name] = dequantize_from_int4(entry["packed"], entry["scale"], entry["shape"]).to(target_dtype).contiguous() - else: - out[name] = entry["data"].to(target_dtype).contiguous() - return out - -# --------------------------------------------------------------------------- -# Ternary diagnostics (logged during training) -# --------------------------------------------------------------------------- -def tern_stats(model: nn.Module, group_size: int = 64): - total = zeros = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (w / scale).round().clamp(-1, 1) - zeros += int((q == 0).sum().item()) - total += int(q.numel()) - return {"zero_frac": zeros / max(total, 1), "total_weights": total} - -_prev_committed: dict = {} - -def churn_fn(model: nn.Module, group_size: int = 64): - global _prev_committed - total = flipped = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (w / scale).round().clamp(-1, 1).cpu().numpy() - if name in _prev_committed: - flipped += int(np.sum(q != _prev_committed[name])) - total += q.size - _prev_committed[name] = q - return flipped / max(total, 1) - -# --------------------------------------------------------------------------- -# Muon optimizer (Newton-Schulz orthogonalized momentum) -# --------------------------------------------------------------------------- -def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, wd: float = 0.0): - super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr, momentum = group["lr"], group["momentum"] - backend_steps, nesterov = group["backend_steps"], group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() - g = ns_orth(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr:curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - wd = group.get("wd", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.mul_(1 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - -# --------------------------------------------------------------------------- -# Data loading -# --------------------------------------------------------------------------- -def ld_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" Tensor: - chunks = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos:self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank, self.world_size, self.device = rank, world_size, device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x, y - -# --------------------------------------------------------------------------- -# Model -# --------------------------------------------------------------------------- -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - -def apply_qat_ste(w: Tensor, fp_storage: str | bool) -> Tensor: - """Applies Straight-Through Estimator (STE) for FP4 or FP8 simulated quantization.""" - if not fp_storage: - return w - if fp_storage == "fp4": - absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(w / scale), -7.0, 7.0) - w_sim = q * scale - return (w_sim - w).detach() + w - elif fp_storage is True or fp_storage == "fp8": - w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) - return (w_sim - w).detach() + w - return w - -class QATLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = False, fp_storage: str | bool = False): - super().__init__(in_features, out_features, bias=bias) - self.fp_storage = fp_storage - - def forward(self, x: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.linear(x, w_qat.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) - -class QATEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, fp_storage: str | bool = False): - super().__init__(num_embeddings, embedding_dim) - self.fp_storage = fp_storage - - def forward(self, input: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.embedding(input, w_qat, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) - -class TernaryLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=False, group_size=64): - super().__init__(in_features, out_features, bias=bias) - self.group_size = group_size - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_g / scale).round().clamp(-1, 1) - w_ternary = w + ((q * scale).reshape(w.shape) - w).detach() - return F.linear(x, w_ternary, - self.bias.to(x.dtype) if self.bias is not None else None) - - -class NormedTernaryLinear(TernaryLinear): - "Ternary linear with RMSNorm on input — for output projections receiving un-normalized activations." - def forward(self, x: Tensor) -> Tensor: - return super().forward(F.rms_norm(x, (x.size(-1),))) - -class GroupedTernaryLinear(nn.Module): - "Grouped linear with ternary STE. Weight stored as 2D [groups*group_out, group_in] for ternary quantization compatibility." - def __init__(self, in_features, out_features, groups=4, group_size=64, normed=False): - super().__init__() - assert in_features % groups == 0 and out_features % groups == 0 - self.groups = groups - self.group_in = in_features // groups - self.group_out = out_features // groups - self.group_size = group_size - self.normed = normed - self.weight = nn.Parameter(torch.randn(groups * self.group_out, self.group_in) * 0.02) - - def forward(self, x: Tensor) -> Tensor: - if self.normed: - x = F.rms_norm(x, (x.size(-1),)) - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_g / scale).round().clamp(-1, 1) - w_ternary = w + ((q * scale).reshape(w.shape) - w).detach() - w_grouped = w_ternary.reshape(self.groups, self.group_out, self.group_in) - bsz = x.shape[:-1] - x_g = x.reshape(*bsz, self.groups, self.group_in) - out = torch.einsum('...gi,goi->...go', x_g, w_grouped) - return out.reshape(*bsz, self.groups * self.group_out) - -class TverskyProjection(nn.Module): - "Tversky similarity: S = θ·f(A∩B) - α·f(A\\B) - β·f(B\\A). Three modes." - def __init__(self, in_features: int, out_features: int, num_features: int = 16, - group_size: int = 64, use_shared_features: bool = False, - membership: str = "sigmoid"): - super().__init__() - self.group_size = group_size - self.num_features = num_features - self.membership_type = membership - self.no_features_mode = (num_features == 0) - - if not self.no_features_mode and not use_shared_features: - self.features = nn.Parameter(torch.empty(num_features, in_features).uniform_(-0.02, 0.02)) - else: - self.register_parameter('features', None) - - self.prototypes = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.02, 0.02)) - self.theta = nn.Parameter(torch.tensor(1.0)) - self.alpha = nn.Parameter(torch.tensor(0.5)) - self.beta = nn.Parameter(torch.tensor(0.5)) - - def _ternary_ste(self, w: Tensor) -> Tensor: - w_bf16 = w.bfloat16() - g = self.group_size - w_grouped = w_bf16.reshape(-1, g) - scale = w_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_grouped / scale).round().clamp(-1, 1) - w_ternary = w_bf16 + ((q * scale).reshape(w_bf16.shape) - w_bf16).detach() - return w_ternary.reshape(w.shape) - - def _membership(self, t: Tensor) -> Tensor: - if self.membership_type == "poly": - return torch.clamp(t * 5.0 / 4.0 + 0.5, 0.0, 1.0) - elif self.membership_type == "tanh": - return (torch.tanh(t * 5.0) + 1.0) * 0.5 - else: - return torch.sigmoid(t * 5.0) - - def forward(self, x: Tensor, shared_features: Tensor | None = None) -> Tensor: - proto = self._ternary_ste(self.prototypes) - - if self.no_features_mode: - # NoFeatures: prototypes are their own feature universe - x_f = x @ proto.t() # [B, S, out] - p_norm = F.normalize(proto, dim=-1) - p_f = p_norm @ p_norm.t() # [out, out] - else: - feat = (shared_features if shared_features is not None else self.features).float() - x_f = x @ feat.t() # [B, S, nf] - p_f = proto @ feat.t() # [out, nf] - - x_s = self._membership(x_f) - p_s = self._membership(p_f) - x_a = x_f * x_s - p_a = p_f * p_s - - t, a, b = self.theta.abs(), self.alpha.abs(), self.beta.abs() - return t * (x_a @ p_a.t()) - a * (x_a @ (1 - p_s).t()) - b * ((1 - x_s) @ p_a.t()) - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: - param.data = param.data.float() - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, no_cache: bool = False, - rope_type: str = "rope", yarn_max_len: int = 4096, train_seq_len: int = 1024): - super().__init__() - self.no_cache = no_cache - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - if rope_type == "yarn": - scale = train_seq_len / yarn_max_len - freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) - ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) - inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len, device, dtype): - if self.no_cache: - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - return freqs.cos()[None, :, None, :].to(dtype=dtype), freqs.sin()[None, :, None, :].to(dtype=dtype) - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size=64, attn_proj_type="standard", tversky_num_features=16, - tversky_feature_pools=0, no_cache=False, rope_type="rope", - yarn_max_len=4096, train_seq_len=1024, tversky_membership="sigmoid", - diff_attn=False): - super().__init__() - self.num_heads, self.num_kv_heads = num_heads, num_kv_heads - self.head_dim = dim // num_heads - self.diff_attn = diff_attn - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.c_qkv = TernaryLinear(dim, self.q_size + 2 * self.kv_size, bias=False, group_size=group_size) - self.proj = NormedTernaryLinear(dim, dim, bias=False, group_size=group_size) if attn_proj_type != "tversky" else None - if self.proj is not None: - self.proj._zero_init = True - self.tversky_proj = TverskyProjection( - dim, dim, num_features=tversky_num_features, group_size=group_size, - use_shared_features=(tversky_feature_pools > 0), - membership=tversky_membership, - ) if attn_proj_type == "tversky" else None - self.shared_features = None - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - if diff_attn: - self.diff_lambda = nn.Parameter(torch.full((num_heads,), 0.5, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, no_cache=no_cache, - rope_type=rope_type, yarn_max_len=yarn_max_len, - train_seq_len=train_seq_len) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qkv_out = self.c_qkv(x) - q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - if self.diff_attn: - half = self.head_dim // 2 - q1, q2 = q[..., :half], q[..., half:] - k1, k2 = k[..., :half], k[..., half:] - v1, v2 = v[..., :half], v[..., half:] - y1 = flash_attn_func(q1.contiguous(), k1.contiguous(), v1.contiguous(), causal=True) - y2 = flash_attn_func(q2.contiguous(), k2.contiguous(), v2.contiguous(), causal=True) - lam = self.diff_lambda.to(dtype=y1.dtype)[None, None, :, None] - y = torch.cat([y1 - lam * y2, y1 + lam * y2], dim=-1) - else: - y = flash_attn_func( - q.contiguous(), - k.contiguous(), - v.contiguous(), - causal=True - ) - y = y.reshape(bsz, seqlen, dim) - return self.tversky_proj(y, self.shared_features) if self.tversky_proj is not None else self.proj(y) - -class MLP(nn.Module): - def __init__(self, dim, mlp_mult, group_size=64, activation="swiglu", mlp_groups=0): - super().__init__() - hidden = mlp_mult * dim - self.activation = activation - if mlp_groups > 0: - if activation == "swiglu": - self.gate_up = GroupedTernaryLinear(dim, hidden * 2, groups=mlp_groups, group_size=group_size) - else: - self.fc = GroupedTernaryLinear(dim, hidden, groups=mlp_groups, group_size=group_size) - self.proj = GroupedTernaryLinear(hidden, dim, groups=mlp_groups, group_size=group_size, normed=True) - else: - if activation == "swiglu": - self.gate_up = TernaryLinear(dim, hidden * 2, bias=False, group_size=group_size) - else: - self.fc = TernaryLinear(dim, hidden, bias=False, group_size=group_size) - self.proj = NormedTernaryLinear(hidden, dim, bias=False, group_size=group_size) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - if self.activation == "swiglu": - gu = self.gate_up(x) - gate, up = gu.chunk(2, dim=-1) - return self.proj(F.silu(gate) * up) - elif self.activation == "relu": - return self.proj(torch.relu(self.fc(x))) - elif self.activation == "leaky_relu": - return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.01)) - else: # relu2 - return self.proj(torch.relu(self.fc(x)).square()) - -class SmearModule(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - cumsum = x.cumsum(dim=1) - counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) - smeared = cumsum / counts - gate = torch.tanh(self.gate.to(dtype=x.dtype)) - return x + gate * (smeared - x) - - -class CausalConvRefiner(nn.Module): - "Causal Conv1d that refines hidden states using local n-gram context." - def __init__(self, dim: int, kernel_size: int = 3): - super().__init__() - self.kernel_size = kernel_size - self.conv = nn.Conv1d(dim, dim, kernel_size, padding=0, bias=False) - self.gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - h = x.permute(0, 2, 1) # [B, D, S] - h = F.pad(h, (self.kernel_size - 1, 0)) # causal pad - h = self.conv(h) - h = h.permute(0, 2, 1) # [B, S, D] - return x + torch.tanh(self.gate.to(dtype=x.dtype)) * F.rms_norm(h, (h.size(-1),)) - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - rope_base: float, qk_gain_init: float, group_size: int=64, - activation: str="swiglu", attn_proj_type: str="standard", - tversky_num_features: int=16, tversky_feature_pools: int=0, no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn: bool=False, mlp_groups: int=0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size, attn_proj_type, tversky_num_features, - tversky_feature_pools, no_cache, rope_type, yarn_max_len, - train_seq_len, tversky_membership, diff_attn) - self.mlp = MLP(dim, mlp_mult, group_size, activation, mlp_groups) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.smear = SmearModule(dim) if smear else None - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0] * x + mix[1] * x0 - n = self.attn_norm(x) - x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) - x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) - if self.smear is not None: - x = self.smear(x) - return x - -class GPT(nn.Module): - def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, - tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, - group_size: int = 64, activation: str = "swiglu", mtp_heads_count: int = 0, - embed_dim: int = 0, attn_proj_type: str = "standard", logit_head_type: str = "standard", - tversky_num_features: int = 16, tversky_feature_pools: int = 0, - training_depth_recurrence: int=1, fp_storage=False, bigram_hash: bool=False, - softcap_type: str="poly", no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn=False, mlp_groups=0, refiner=False, refiner_kernel=3): - super().__init__() - self.training_depth_recurrence = training_depth_recurrence - self.fp_storage = fp_storage - self.tie_embeddings = tie_embeddings - self.logit_softcap = logit_softcap - self.softcap_type = softcap_type - self.embed_dim = embed_dim if embed_dim > 0 else model_dim - self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) - self.bigram_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) if bigram_hash else None - if self.bigram_emb is not None: - nn.init.zeros_(self.bigram_emb.weight) - self.lm_head_correction = nn.Parameter( - torch.zeros(vocab_size, self.embed_dim)) if tie_embeddings == 2 else None - self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, fp_storage=fp_storage) if self.embed_dim != model_dim else None - self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, fp_storage=fp_storage) if ( - self.embed_dim != model_dim and logit_head_type != "tversky") else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - - # Shared Tversky feature pools (if enabled and num_features > 0) - if attn_proj_type == "tversky" and tversky_feature_pools > 0 and tversky_num_features > 0: - self.tversky_feature_pools_list = nn.ParameterList([ - nn.Parameter(torch.empty(tversky_num_features, model_dim).uniform_(-0.02, 0.02)) - for _ in range(tversky_feature_pools) - ]) - else: - self.tversky_feature_pools_list = None - - self.blocks = nn.ModuleList([ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - group_size, activation, attn_proj_type, tversky_num_features, tversky_feature_pools, - no_cache, smear, rope_type, yarn_max_len, train_seq_len, tversky_membership, - diff_attn, mlp_groups) - for _ in range(num_layers) - ]) - - # Inject shared feature pool references into attention layers - if self.tversky_feature_pools_list is not None: - for i, block in enumerate(self.blocks): - pool_idx = (i * tversky_feature_pools) // num_layers - block.attn.shared_features = self.tversky_feature_pools_list[pool_idx] - - self.final_norm = RMSNorm() - self.refiner = CausalConvRefiner(model_dim, kernel_size=refiner_kernel) if refiner else None - self.mtp_heads = nn.ModuleList([ - nn.Linear(model_dim, vocab_size, bias=False) for _ in range(mtp_heads_count) - ]) - for h in self.mtp_heads: - nn.init.zeros_(h.weight) - self.logit_head_type = logit_head_type - if logit_head_type == "tversky" and tversky_num_features == 0 and vocab_size > 1024: - raise ValueError( - f"Tversky logit head with no-features mode creates O(V^2) = {vocab_size}x{vocab_size} " - f"matrix per forward pass. Use tversky_num_features > 0 or a smaller vocab." - ) - self.tversky_head = TverskyProjection( - model_dim, vocab_size, num_features=tversky_num_features, - membership=tversky_membership, - ) if logit_head_type == "tversky" else None - self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) - self.lm_head._zero_init = True - if self.lm_head is not None and (tie_embeddings or logit_head_type == "tversky"): - self.lm_head.weight.requires_grad_(False) - - self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) - self._init_weights(tied_embed_init_std) - - def _init_weights(self, tied_embed_init_std: float) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) - for module in self.modules(): - if isinstance(module, TernaryLinear) and not getattr(module, "_zero_init", False): - nn.init.normal_(module.weight, mean=0.0, std=0.02) - elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def _compute_logits(self, x: Tensor) -> Tensor: - if self.tversky_head is not None: - logits_raw = self.tversky_head(x) - elif self.tie_embeddings: - if self.embed_proj_rev is not None: - proj = self.embed_proj_rev(x) - else: - proj = x - weight = self.tok_emb.weight - if self.lm_head_correction is not None: - weight = weight + self.lm_head_correction - logits_raw = F.linear(proj, weight.to(x.dtype)) - else: - logits_raw = self.lm_head(x) - return logits_raw + self.vocab_bias.to(x.dtype) - - def _softcap(self, logits: Tensor) -> Tensor: - s = self.logit_softcap - if self.softcap_type == "tanh": - return s * torch.tanh(logits / s) - x_sc = torch.clamp(logits / s, -2.0, 2.0) - x2 = x_sc * x_sc - return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) - - def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", temperature: float = 1.0) -> Tensor: - x = self.tok_emb(input_ids).float() - if self.bigram_emb is not None: - prev = F.pad(input_ids[:, :-1], (1, 0), value=0) - x = x + self.bigram_emb(prev).float() - if self.embed_proj is not None: - x = self.embed_proj(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - - # U-Net style encoder/decoder with skip connections - skips = [] - for i in range(self.num_encoder_layers): - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[bi](x, x0) - - x_normed = self.final_norm(x) - if self.refiner is not None: - x_normed = self.refiner(x_normed) - - # Standard training/eval path - x_flat = x_normed.reshape(-1, x_normed.size(-1)) - targets = target_ids.reshape(-1) - logits = self._softcap(self._compute_logits(x_flat)) - - if reduction == "none": - return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) - - # Fused CE + Z-loss: single logsumexp computation - logits_f = logits.float() - lse = torch.logsumexp(logits_f, dim=-1) - target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) - main_loss = (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() - - # Multi-token prediction auxiliary loss (training only) - if self.training and len(self.mtp_heads) > 0: - mtp_loss = torch.zeros((), device=main_loss.device) - for k, head in enumerate(self.mtp_heads): - shift = k + 2 - if target_ids.shape[1] > shift: - mtp_tgt = target_ids[:, shift:].reshape(-1) - mtp_in = x_normed[:, :target_ids.shape[1] - shift, :].reshape(-1, x_normed.shape[-1]) - mtp_loss = mtp_loss + F.cross_entropy(head(mtp_in).float(), mtp_tgt, reduction="mean") - main_loss = main_loss + 0.1 * mtp_loss / len(self.mtp_heads) - return main_loss - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- - -def build_luts(sp, vocab_size: int, device: torch.device): - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): - files = sorted(glob.glob(pattern)) - assert files, f"No files: {pattern}" - tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() - if max_tok > 0: tok = tok[:max_tok + 1] - u = ((tok.numel() - 1) // seq_len) * seq_len - return tok[:u + 1] - -def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature: float = 1.0): - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_start in range(seq_start, seq_end, local_batch_seqs): - batch_end = min(batch_start + local_batch_seqs, seq_end) - raw_start = batch_start * args.train_seq_len - raw_end = batch_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) - x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch_loss = model(x, y, temperature=temperature).detach() - n = float(y.numel()) - loss_sum += batch_loss.to(torch.float64) * n - token_count += n - prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) - tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) - tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride: int = 64, temperature: float = 1.0): - seq_len = args.train_seq_len - batch_size = args.sliding_batch_size - total_tokens = val_tokens.numel() - 1 - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - all_starts = list(range(0, total_tokens - seq_len, stride)) - my_starts = all_starts[rank::world_size] - - model.eval() - with torch.inference_mode(): - for i in range(0, len(my_starts), batch_size): - batch_starts = my_starts[i:i + batch_size] - - starts_t = torch.tensor(batch_starts, dtype=torch.int64) - offsets = torch.arange(seq_len + 1, dtype=torch.int64) - indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) - - local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) - x = local_batch[:, :-1] - y = local_batch[:, 1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() - - for b, start in enumerate(batch_starts): - score_from = 0 if start == 0 else seq_len - stride - scored = per_token_loss[b, score_from:] - sx, sy = x[b, score_from:], y[b, score_from:] - - loss_sum += scored.to(torch.float64).sum() - token_count += scored.numel() - - tok_bytes = base_bytes_lut[sy].to(torch.int16) - tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -# --------------------------------------------------------------------------- -# Temperature scaling -# --------------------------------------------------------------------------- -def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut): - best_t, best_loss = 1.0, float("inf") - for t in [0.90, 0.95, 1.00, 1.05, 1.10]: - loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, temperature=t) - if loss < best_loss: - best_loss = loss - best_t = t - return best_t - -# --------------------------------------------------------------------------- -# Training -# --------------------------------------------------------------------------- -def main() -> None: - args = Hyperparameters() - code = Path(__file__).read_text(encoding="utf-8") - - if args.matrix_optimizer != "adamw": - global ns_orth - ns_orth = torch.compile(ns_orth) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - grad_accum_steps = max(1, 8 // world_size) - grad_scale = 1.0 / grad_accum_steps - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - os.makedirs("logs/cuda/", exist_ok=True) - logfile = f"logs/cuda/{args.run_id}.txt" if master_process else None - if master_process: - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - - log0(f"Python {sys.version}", console=False) - log0(f"PyTorch {torch.__version__}", console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - val_tokens = ld_val(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts( - sp, args.vocab_size, device) - - # --- Model --- - base_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - group_size=args.bitnet_group_size, activation=args.activation_type, mtp_heads_count=args.mtp_heads_count, - embed_dim=args.embed_dim, attn_proj_type=args.attn_proj_type, logit_head_type=args.logit_head_type, - tversky_num_features=args.tversky_num_features, tversky_feature_pools=args.tversky_feature_pools, - training_depth_recurrence=args.training_depth_recurrence, fp_storage=args.fp_storage, - bigram_hash=args.bigram_hash, softcap_type=args.softcap_type, no_cache=(args.compile_mode == "reduce-overhead"), - smear=args.smear, rope_type=args.rope_type, yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, - tversky_membership=args.tversky_membership, diff_attn=args.diff_attn, - refiner=args.refiner, refiner_kernel=args.refiner_kernel, mlp_groups=args.mlp_groups, - ).to(device).bfloat16() - - for module in base_model.modules(): - if isinstance(module, nn.Linear): - module.float() - restore_low_dim_params_to_fp32(base_model) - if base_model.lm_head is not None and (args.tie_embeddings or args.logit_head_type == "tversky"): - base_model.lm_head.weight.requires_grad_(False) - - torch._dynamo.config.optimize_ddp = False - - compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else None) - use_find_unused = args.untie_at_fraction > 0 or args.mtp_heads_count > 0 or not args.tie_embeddings - model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, - find_unused_parameters=use_find_unused, - static_graph=not use_find_unused, - gradient_as_bucket_view=True) if distributed else compiled_model - - # --- Optimizers --- - _excl = {"tok_emb.weight", "lm_head.weight", "lm_head_correction"} - all_other_params = [(n, p) for n, p in base_model.named_parameters() - if not any(eh in n for eh in _excl)] - matrix_params = [p for n, p in all_other_params - if p.ndim == 2 and not any(pat in n for pat in CTP)] - scalar_params = [p for n, p in all_other_params - if p.ndim < 2 or any(pat in n for pat in CTP)] - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - opt_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - if args.matrix_optimizer == "adamw": - opt_muon = torch.optim.AdamW( - [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) - else: - opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, wd=args.muon_wd) - for g in opt_muon.param_groups: - g["base_lr"] = args.matrix_lr - opt_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - opt_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - - optimizers = [opt for opt in [opt_tok, opt_muon, opt_scalar, opt_head] if opt is not None] - - if base_model.lm_head_correction is not None: - opt_corr = torch.optim.Adam( - [{"params": [base_model.lm_head_correction], - "lr": args.corr_weight_lr, "base_lr": args.corr_weight_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - optimizers.append(opt_corr) - - # --- Log all hyperparameters --- - log0("--- Hyperparameters ---", console=False) - log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) if not a.startswith("_") and a not in ("train_files","val_files") and not callable(getattr(args,a))), console=False) - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"params:{n_params} L:{args.num_layers} d:{args.model_dim} h:{args.num_heads} kv:{args.num_kv_heads} ws:{world_size} ga:{grad_accum_steps} s:{args.seed}") - - # --- Data loader & helpers --- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all(): - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float): - if args.warmdown_fraction <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) - return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 - warmdown_ms = max_wallclock_ms * args.warmdown_fraction - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - _seq_switched = False - _batch_switched = False - active_seq_len = args.seq_len_start if args.seq_len_start > 0 else args.train_seq_len - active_batch_tokens = args.batch_tokens_start if args.batch_tokens_start > 0 else args.train_batch_tokens - - # --- Compiler warmup --- - if args.warmup_steps > 0: - _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} - _os = [copy.deepcopy(o.state_dict()) for o in optimizers] - model.train() - for ws in range(args.warmup_steps): - zero_grad_all() - for mi in range(grad_accum_steps): - if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) - (loss * grad_scale).backward() - for o in optimizers: o.step() - zero_grad_all() - log0(f"warmup:{ws+1}/{args.warmup_steps}") - base_model.load_state_dict(_ms, strict=True) - for o, s in zip(optimizers, _os): o.load_state_dict(s) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # --- Main training loop --- - training_time_ms = 0.0 - stop_after_step: int | None = None - _untied = False - train_loss = torch.zeros((), device=device) - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - tstats = tern_stats(base_model, group_size=args.bitnet_group_size) - log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms zero_frac:{tstats['zero_frac']:.3f}") - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - - # Sequence length schedule - if args.seq_len_start > 0 and not _seq_switched: - if max_wallclock_ms is not None: - should_switch_seq = elapsed_ms >= args.seq_schedule_fraction * max_wallclock_ms - else: - should_switch_seq = step >= int(args.iterations * args.seq_schedule_fraction) - if should_switch_seq: - active_seq_len = args.train_seq_len - _seq_switched = True - torch._dynamo.reset() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - log0(f"step:{step} seq_len_switch:{args.seq_len_start}->{active_seq_len}") - - # Batch size schedule - if args.batch_tokens_start > 0 and not _batch_switched: - if max_wallclock_ms is not None: - should_switch_batch = elapsed_ms >= args.batch_schedule_fraction * max_wallclock_ms - else: - should_switch_batch = step >= int(args.iterations * args.batch_schedule_fraction) - if should_switch_batch: - active_batch_tokens = args.train_batch_tokens - _batch_switched = True - log0(f"step:{step} batch_switch:{args.batch_tokens_start}->{active_batch_tokens}") - - zero_grad_all() - train_loss.zero_() - - for micro in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = model(x, y) - train_loss.add_(loss.detach()) - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - # Untie lm_head at configured fraction of training - if args.untie_at_fraction > 0: - if max_wallclock_ms is not None: - should_untie = not _untied and elapsed_ms >= args.untie_at_fraction * max_wallclock_ms - else: - should_untie = not _untied and step >= int(args.iterations * args.untie_at_fraction) - if should_untie and base_model.tie_embeddings: - with torch.no_grad(): - base_weight = base_model.tok_emb.weight.float() - if base_model.lm_head_correction is not None: - base_weight = base_weight + base_model.lm_head_correction.float() - if base_model.embed_proj_rev is not None: - full_weight = base_weight @ base_model.embed_proj_rev.weight.float() - else: - full_weight = base_weight - base_model.lm_head.weight.copy_(full_weight) - base_model.tie_embeddings = False - base_model.lm_head.weight.requires_grad_(True) - for g in opt_head.param_groups: - g["lr"] = g["base_lr"] = args.head_lr - _untied = True - torch._dynamo.reset() - log0(f"step:{step} untied lm_head (head_lr={args.head_lr})") - - # Muon momentum warmup - if args.matrix_optimizer != "adam": - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - for g in opt_muon.param_groups: - g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - - # LR scheduling - for opt in optimizers: - for g in opt.param_groups: - g["lr"] = g["base_lr"] * scale - opt.step() - zero_grad_all() - step += 1 - approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if args.train_log_every > 0 and step % args.train_log_every == 0: - log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} t:{approx_ms:.0f}ms avg:{approx_ms/step:.1f}ms") - if args.churn_log_every > 0 and step % args.churn_log_every == 0: - log0(f"step:{step} churn:{churn_fn(base_model, args.bitnet_group_size):.4f} zero:{tern_stats(base_model, args.bitnet_group_size)['zero_frac']:.3f}") - - # Wallclock cap sync - if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: - reached_cap = approx_ms >= max_wallclock_ms - if distributed: - cap_t = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) - reached_cap = bool(cap_t.item()) - if reached_cap: - stop_after_step = step - - # --- Serialization --- - if master_process: - sd = base_model.state_dict() - if base_model.tie_embeddings or args.logit_head_type == "tversky": - sd.pop("lm_head.weight", None) - - # Compute ternary overrides for no-features Tversky prototypes - ternary_overrides = set() - for n, m in base_model.named_modules(): - if isinstance(m, TverskyProjection) and m.no_features_mode: - ternary_overrides.add(n + ".prototypes") - ternary_overrides = ternary_overrides or None - - # Two methods: Standard Base-3 vs Bitmask Mapping - methods = {} - for method in ("standard", "bitmask"): - q_obj, stats = q_sd(sd, group_size=args.bitnet_group_size, fp_storage=args.fp_storage, ternary_method=method, ternary_override_names=ternary_overrides) - buf = io.BytesIO() - torch.save(q_obj, buf) - methods[method] = {"blob": lzma.compress(buf.getvalue(), preset=9), "stats": stats} - best = min(methods, key=lambda m: len(methods[m]["blob"])) - final_blob, q_stats = methods[best]["blob"], methods[best]["stats"] - with open("final_model.ternary.ptz", "wb") as f: - f.write(final_blob) - - artifact_bytes = len(final_blob) - code_bytes = len(code.encode("utf-8")) - - total = artifact_bytes + code_bytes - log0(f"artifact:{artifact_bytes/1e6:.2f}MB ternary:{q_stats['ternary_params']}({q_stats['ternary_bytes']}B) fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") - log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) {'FITS' if total <= 16000000 else 'OVER'}") - - if args.eval_depth_recurrence > 0: - base_model.training_depth_recurrence = args.eval_depth_recurrence - log0(f"eval_depth_recurrence:{args.eval_depth_recurrence}") - - # --- All ranks load roundtrip weights and evaluate --- - if distributed: - dist.barrier() - - with open("final_model.ternary.ptz", "rb") as f: - loaded = torch.load(io.BytesIO(lzma.decompress(f.read())), map_location="cpu", weights_only=False) - base_model.load_state_dict(deq_sd(loaded), strict=False) - torch._dynamo.reset() - - q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - log0(f"final_ternary_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") - - opt_temp = 1.0 - if args.temp_scaling: - torch.cuda.synchronize() - t_temp = time.perf_counter() - calibration_tokens = train_loader.stream.take(65536).to(device) - opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut) - torch.cuda.synchronize() - temp_time_ms = 1000.0 * (time.perf_counter() - t_temp) - log0(f"temp_scaling optimal_T:{opt_temp:.2f} eval_time:{temp_time_ms:.0f}ms") - - if args.sliding_eval: - torch.cuda.synchronize() - t_sliding = time.perf_counter() - sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, stride=args.sliding_eval_stride, - temperature=opt_temp) - torch.cuda.synchronize() - sliding_time_ms = 1000.0 * (time.perf_counter() - t_sliding) - log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " - f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) eval_time:{sliding_time_ms:.0f}ms") - - if distributed: - dist.destroy_process_group() - -if __name__ == "__main__": - main() -==================================================================================================== -Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] -PyTorch 2.10.0+cu128 ---- Hyperparameters --- -activation_type=relu2 adam_eps=1e-08 adam_lr=0.05 adam_wd=0.05 attn_proj_type=standard batch_schedule_fraction=0.33 batch_tokens_start=0 beta1=0.9 beta2=0.95 bigram_hash=False bitnet_group_size=128 churn_log_every=0 compile_mode=default corr_weight_lr=0.02 data_path=./data/datasets/fineweb10B_sp8192 diff_attn=False embed_dim=254 embed_lr=0.6 eval_depth_recurrence=0 fp_storage=True grad_clip_norm=0.0 head_lr=0.02 iterations=10000 logit_head_type=standard logit_softcap=10.0 matrix_lr=0.04 matrix_optimizer=muon max_wallclock_seconds=599.0 mlp_groups=0 mlp_mult=4 model_dim=768 mtp_heads_count=0 muon_backend_steps=3 muon_momentum=0.95 muon_momentum_warmup_start=0.85 muon_momentum_warmup_steps=500 muon_wd=0.0 num_heads=8 num_kv_heads=4 num_layers=10 qk_gain_init=2.25 refiner=False refiner_kernel=3 rope_base=5000.0 rope_type=yarn run_id=pushing_run_ternary_2 scalar_lr=0.02 seed=42 seq_len_start=0 seq_schedule_fraction=0.0 sliding_batch_size=256 sliding_eval=True sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.02 tokenizer_path=./data/tokenizers/fineweb_8192_bpe.model train_batch_tokens=524288 train_log_every=1000 train_seq_len=1024 training_depth_recurrence=0 tversky_feature_pools=1 tversky_membership=sigmoid tversky_num_features=128 untie_at_fraction=0.0 val_batch_size=524288 val_loss_every=0 vocab_size=8192 warmdown_fraction=0.2 warmup_steps=5 yarn_max_len=2048 -params:73685840 L:10 d:768 h:8 kv:4 ws:8 ga:1 s:42 -warmup:1/5 -warmup:2/5 -warmup:3/5 -warmup:4/5 -warmup:5/5 -step:1000/10000 loss:3.3120 t:91478ms avg:91.5ms -step:2000/10000 loss:3.2883 t:183345ms avg:91.7ms -step:3000/10000 loss:3.1489 t:275315ms avg:91.8ms -step:4000/10000 loss:3.3138 t:367178ms avg:91.8ms -step:5000/10000 loss:3.1796 t:459028ms avg:91.8ms -step:6000/10000 loss:3.0211 t:550871ms avg:91.8ms -step:6530/10000 val_loss:3.0517 val_bpb:1.1816 train_time:599652ms zero_frac:0.336 -stopping_early: wallclock_cap train_time:599652ms step:6530/10000 -artifact:15.92MB ternary:64880640(14239486B) fp:2513744(2537376B) code:70853 -budget:15993853/16000000 (15.99/16.00MB) FITS -final_ternary_roundtrip val_loss:3.0570 val_bpb:1.1837 -temp_scaling optimal_T:0.90 eval_time:146ms -final_sliding val_loss:2.9869 val_bpb:1.1565 (stride=16, T=0.90) eval_time:429311ms diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/ternary_log_7.txt b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/ternary_log_7.txt deleted file mode 100644 index 643c4111d7..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/ternary_log_7.txt +++ /dev/null @@ -1,1460 +0,0 @@ -"Ternary training script for OpenAI's Parameter Golf Challenge. Ciprian-Florin Ifrim - 24 March 2026" - -import copy -import glob -import io -import math -import os -import random -import sys -import time -import lzma -from pathlib import Path -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func - -# --------------------------------------------------------------------------- -# Hyperparameters (all configurable via environment variables) -# --------------------------------------------------------------------------- -def _e(k, d, t=str): - v = os.environ.get(k, str(d)) - if t == bool: return bool(int(v)) - return t(v) - -class Hyperparameters: - data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", f"run_{int(time.time())}") - seed = _e("SEED", 1337, int) - compile_mode = _e("COMPILE_MODE", "default") - val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) - val_loss_every = _e("VAL_LOSS_EVERY", 500, int) - train_log_every = _e("TRAIN_LOG_EVERY", 10, int) - iterations = _e("ITERATIONS", 2000, int) - warmdown_fraction = _e("WARMDOWN_FRACTION", 0.2, float) - warmup_steps = _e("WARMUP_STEPS", 20, int) - train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) - train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) - max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) - vocab_size = _e("VOCAB_SIZE", 1024, int) - num_layers = _e("NUM_LAYERS", 16, int) - num_kv_heads = _e("NUM_KV_HEADS", 4, int) - model_dim = _e("MODEL_DIM", 512, int) - num_heads = _e("NUM_HEADS", 8, int) - mlp_mult = _e("MLP_MULT", 2, int) - tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) - rope_base = _e("ROPE_BASE", 10000.0, float) - rope_type = _e("ROPE_TYPE", "rope") - yarn_max_len = _e("YARN_MAX_LEN", 4096, int) - logit_softcap = _e("LOGIT_SOFTCAP", 30.0, float) - softcap_type = _e("SOFTCAP_TYPE", "poly") - tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) - qk_gain_init = _e("QK_GAIN_INIT", 1.5, float) - activation_type = _e("ACTIVATION", "swiglu") - embed_dim = _e("EMBED_DIM", 0, int) - bigram_hash = _e("BIGRAM_HASH", 0, bool) - mtp_heads_count = _e("MTP_HEADS", 0, int) - training_depth_recurrence = _e("TRAINING_DEPTH_RECURRENCE", 1, int) - eval_depth_recurrence = _e("EVAL_DEPTH_RECURRENCE", 1, int) - attn_proj_type = _e("ATTN_PROJ_TYPE", "standard") - logit_head_type = _e("LOGIT_HEAD_TYPE", "standard") - tversky_num_features = _e("TVERSKY_NUM_FEATURES", 16, int) - tversky_feature_pools = _e("TVERSKY_FEATURE_POOLS", 0, int) - tversky_membership = _e("TVERSKY_MEMBERSHIP", "sigmoid") - diff_attn = _e("DIFF_ATTN", 0, bool) - refiner = _e("REFINER", 0, bool) - refiner_kernel = _e("REFINER_KERNEL", 3, int) - mlp_groups = _e("MLP_GROUPS", 0, int) - embed_lr = _e("EMBED_LR", 0.6, float) - head_lr = _e("HEAD_LR", 0.008, float) - adam_lr = _e("ADAM_LR", 1e-3, float) - adam_wd = _e("ADAM_WD", 0.05, float) - untie_at_fraction = _e("UNTIE_AT_FRACTION", 0.0, float) - tied_embed_lr = _e("TIED_EMBED_LR", 0.05, float) - corr_weight_lr = _e("CORR_WEIGHT_LR", 0.05, float) - smear = _e("SMEAR", 0, bool) - seq_len_start = _e("SEQ_LEN_START", 0, int) - seq_schedule_fraction = _e("SEQ_SCHEDULE_FRACTION", 0.33, float) - batch_tokens_start = _e("BATCH_TOKENS_START", 0, int) - batch_schedule_fraction = _e("BATCH_SCHEDULE_FRACTION", 0.33, float) - churn_log_every = _e("CHURN_LOG_EVERY", 500, int) - matrix_lr = _e("MATRIX_LR", 0.04, float) - scalar_lr = _e("SCALAR_LR", 0.04, float) - muon_momentum = _e("MUON_MOMENTUM", 0.95, float) - muon_backend_steps = _e("MUON_BACKEND_STEPS", 5, int) - muon_wd = _e("MUON_WD", 0.0, float) - matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") - muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) - muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) - beta1 = _e("BETA1", 0.9, float) - beta2 = _e("BETA2", 0.95, float) - adam_eps = _e("ADAM_EPS", 1e-8, float) - grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) - bitnet_group_size = _e("BITNET_GROUP_SIZE", 64, int) - sliding_eval = _e("SLIDING_EVAL", 0, bool) - sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 64, int) - sliding_batch_size = _e("SLIDING_BATCH_SIZE", 64, int) - temp_scaling = _e("TEMP_SCALING", 0, bool) - _fp_raw = os.environ.get("FP_STORAGE", "0") - fp_storage = True if _fp_raw == "FP8" else ("fp4" if _fp_raw == "FP4" else False) - -CTP = ("attn_scale","attn_scales","mlp_scale","mlp_scales","resid_mix","resid_mixes","q_gain","diff_lambda","skip_weight","skip_weights","vocab_bias","refiner.gate") - -# --------------------------------------------------------------------------- -# Ternary packing — base-3 encoding (5 trits/byte) -# --------------------------------------------------------------------------- -def pack_ternary(q: Tensor): - f = (q.reshape(-1).to(torch.int8) + 1).numpy() - n = len(f) - p = (5 - n % 5) % 5 - if p: f = np.concatenate([f, np.zeros(p, dtype=np.int8)]) - g = f.reshape(-1, 5).astype(np.uint8) - return (g[:,0] + g[:,1]*3 + g[:,2]*9 + g[:,3]*27 + g[:,4]*81).tobytes(), n - -def unpack_ternary(data: bytes, n: int) -> Tensor: - v = np.frombuffer(data, dtype=np.uint8).astype(np.int16) - t = np.zeros((len(v), 5), dtype=np.int8) - for i in range(5): t[:,i] = v % 3; v //= 3 - return torch.from_numpy(t.reshape(-1)[:n].astype(np.int8) - 1) - -def pack_ternary_bitmask(q: Tensor): - f = q.reshape(-1).to(torch.int8).numpy(); n = len(f) - nz = (f != 0) - return np.packbits(nz).tobytes() + np.packbits(f[nz] > 0).tobytes(), n - -def unpack_ternary_bitmask(data: bytes, n: int) -> Tensor: - ms = (n + 7) // 8 - nz = np.unpackbits(np.frombuffer(data[:ms], dtype=np.uint8))[:n].astype(bool) - s = np.unpackbits(np.frombuffer(data[ms:], dtype=np.uint8))[:int(nz.sum())].astype(bool) - w = np.zeros(n, dtype=np.int8); w[nz] = np.where(s, 1, -1) - return torch.from_numpy(w) - -# --------------------------------------------------------------------------- -# FP4 quantization (per-row absmax, 2 values packed per byte) -# --------------------------------------------------------------------------- -def quantize_to_int4(t: Tensor) -> tuple[Tensor, Tensor, list]: - t32 = t.float() - orig_shape = t32.shape - if t32.ndim < 2: - t32 = t32.unsqueeze(0) - absmax = t32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(t32 / scale), -7, 7).to(torch.int8) - flat = q.reshape(-1) - if flat.numel() % 2 != 0: - flat = F.pad(flat, (0, 1)) - low = (flat[0::2] + 8).to(torch.uint8) - high = (flat[1::2] + 8).to(torch.uint8) - return low | (high << 4), scale.half().squeeze(-1), list(orig_shape) - -def dequantize_from_int4(packed: Tensor, scale: Tensor, shape: list) -> Tensor: - low = (packed & 0x0F).to(torch.int8) - 8 - high = ((packed >> 4) & 0x0F).to(torch.int8) - 8 - flat = torch.zeros(packed.numel() * 2, dtype=torch.int8) - flat[0::2] = low - flat[1::2] = high - numel = 1 - for s in shape: - numel *= s - flat = flat[:numel].float() - if len(shape) <= 1: - return (flat * scale.float().squeeze()).reshape(shape) - return (flat.reshape(-1, shape[-1]) * scale.float().unsqueeze(-1)).reshape(shape) - -# --------------------------------------------------------------------------- -# State dict serialization (ternary + fp16/fp8/fp4) -# --------------------------------------------------------------------------- -def q_sd(state_dict: dict, group_size: int = 64, fp_storage=False, ternary_method="standard", ternary_override_names: set | None = None) -> tuple[dict, dict]: - "Ternary for large 2D weight matrices, fp16/fp8/fp4 for everything else." - quantized = {} - stats = {"ternary_params": 0, "ternary_bytes": 0, "fp_params": 0, "fp_bytes": 0} - for name, tensor in state_dict.items(): - if "mtp_heads" in name: - continue - t = tensor.detach().cpu().float().contiguous() - t_orig_shape = list(t.shape) - if t.ndim == 3: - t = t.reshape(t.shape[0], -1) - is_ternary_candidate = ( - t.ndim == 2 and t.numel() > 65_536 - and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name and "bigram_emb" not in name and "lm_head_correction" not in name and "lm_head_U" not in name and "lm_head_V" not in name - and "prototypes" not in name and "tversky" not in name - ) or (ternary_override_names is not None and name in ternary_override_names) - if is_ternary_candidate: - pad = (group_size - t.shape[1] % group_size) % group_size - t_padded = F.pad(t, (0, pad)) if pad > 0 else t - t_grouped = t_padded.reshape(-1, group_size) - scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (t_grouped / scale).round().clamp(-1, 1).to(torch.int8) - - if ternary_method == "standard": - packed_bytes, n_trits = pack_ternary(q) - entry_type = "ternary" - else: - packed_bytes, n_trits = pack_ternary_bitmask(q) - entry_type = "ternary_bitmask" - - quantized[name] = { - "type": entry_type, "packed": packed_bytes, - "scale": scale.half().squeeze(-1), - "shape": list(t.shape), "padded_cols": t_padded.shape[1], - "group_size": group_size, "n_trits": n_trits, - "orig_shape": t_orig_shape, - } - stats["ternary_params"] += t.numel() - stats["ternary_bytes"] += len(packed_bytes) + scale.numel() * 2 - elif fp_storage == "fp4" and t.ndim == 2: - packed, scale, orig_shape = quantize_to_int4(t) - quantized[name] = {"type": "fp4", "packed": packed, "scale": scale, "shape": orig_shape} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += packed.numel() + scale.numel() * 2 - elif fp_storage and t.ndim == 2: - quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() - else: - quantized[name] = {"type": "fp16", "data": t.half()} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() * 2 - return quantized, stats - -def deq_sd(quantized: dict, target_dtype=torch.bfloat16): - "Reconstruct full-precision state dict from quantized representation." - out = {} - for name, entry in quantized.items(): - if entry["type"] in ("ternary", "ternary_bitmask"): - if entry["type"] == "ternary": - q = unpack_ternary(entry["packed"], entry["n_trits"]) - else: - q = unpack_ternary_bitmask(entry["packed"], entry["n_trits"]) - - q = q.float().reshape(-1, entry["group_size"]) - scale = entry["scale"].float().unsqueeze(-1) - q_absmean = q.abs().mean(-1, keepdim=True).clamp(min=1e-8) - t = (q * (scale / q_absmean)).reshape(-1, entry["padded_cols"]) - shape = entry["shape"] - result = t[:shape[0], :shape[1]].to(target_dtype) - orig = entry.get("orig_shape") - out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() - elif entry["type"] == "fp8": - out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() - elif entry["type"] == "fp4": - out[name] = dequantize_from_int4(entry["packed"], entry["scale"], entry["shape"]).to(target_dtype).contiguous() - else: - out[name] = entry["data"].to(target_dtype).contiguous() - return out - -# --------------------------------------------------------------------------- -# Ternary diagnostics (logged during training) -# --------------------------------------------------------------------------- -def tern_stats(model: nn.Module, group_size: int = 64): - total = zeros = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (w / scale).round().clamp(-1, 1) - zeros += int((q == 0).sum().item()) - total += int(q.numel()) - return {"zero_frac": zeros / max(total, 1), "total_weights": total} - -_prev_committed: dict = {} - -def churn_fn(model: nn.Module, group_size: int = 64): - global _prev_committed - total = flipped = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (w / scale).round().clamp(-1, 1).cpu().numpy() - if name in _prev_committed: - flipped += int(np.sum(q != _prev_committed[name])) - total += q.size - _prev_committed[name] = q - return flipped / max(total, 1) - -# --------------------------------------------------------------------------- -# Muon optimizer (Newton-Schulz orthogonalized momentum) -# --------------------------------------------------------------------------- -def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, wd: float = 0.0): - super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr, momentum = group["lr"], group["momentum"] - backend_steps, nesterov = group["backend_steps"], group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() - g = ns_orth(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr:curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - wd = group.get("wd", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.mul_(1 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - -# --------------------------------------------------------------------------- -# Data loading -# --------------------------------------------------------------------------- -def ld_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" Tensor: - chunks = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos:self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank, self.world_size, self.device = rank, world_size, device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x, y - -# --------------------------------------------------------------------------- -# Model -# --------------------------------------------------------------------------- -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - -def apply_qat_ste(w: Tensor, fp_storage: str | bool) -> Tensor: - """Applies Straight-Through Estimator (STE) for FP4 or FP8 simulated quantization.""" - if not fp_storage: - return w - if fp_storage == "fp4": - absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(w / scale), -7.0, 7.0) - w_sim = q * scale - return (w_sim - w).detach() + w - elif fp_storage is True or fp_storage == "fp8": - w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) - return (w_sim - w).detach() + w - return w - -class QATLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = False, fp_storage: str | bool = False): - super().__init__(in_features, out_features, bias=bias) - self.fp_storage = fp_storage - - def forward(self, x: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.linear(x, w_qat.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) - -class QATEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, fp_storage: str | bool = False): - super().__init__(num_embeddings, embedding_dim) - self.fp_storage = fp_storage - - def forward(self, input: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.embedding(input, w_qat, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) - -class TernaryLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=False, group_size=64): - super().__init__(in_features, out_features, bias=bias) - self.group_size = group_size - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_g / scale).round().clamp(-1, 1) - w_ternary = w + ((q * scale).reshape(w.shape) - w).detach() - return F.linear(x, w_ternary, - self.bias.to(x.dtype) if self.bias is not None else None) - - -class NormedTernaryLinear(TernaryLinear): - "Ternary linear with RMSNorm on input — for output projections receiving un-normalized activations." - def forward(self, x: Tensor) -> Tensor: - return super().forward(F.rms_norm(x, (x.size(-1),))) - -class GroupedTernaryLinear(nn.Module): - "Grouped linear with ternary STE. Weight stored as 2D [groups*group_out, group_in] for ternary quantization compatibility." - def __init__(self, in_features, out_features, groups=4, group_size=64, normed=False): - super().__init__() - assert in_features % groups == 0 and out_features % groups == 0 - self.groups = groups - self.group_in = in_features // groups - self.group_out = out_features // groups - self.group_size = group_size - self.normed = normed - self.weight = nn.Parameter(torch.randn(groups * self.group_out, self.group_in) * 0.02) - - def forward(self, x: Tensor) -> Tensor: - if self.normed: - x = F.rms_norm(x, (x.size(-1),)) - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_g / scale).round().clamp(-1, 1) - w_ternary = w + ((q * scale).reshape(w.shape) - w).detach() - w_grouped = w_ternary.reshape(self.groups, self.group_out, self.group_in) - bsz = x.shape[:-1] - x_g = x.reshape(*bsz, self.groups, self.group_in) - out = torch.einsum('...gi,goi->...go', x_g, w_grouped) - return out.reshape(*bsz, self.groups * self.group_out) - -class TverskyProjection(nn.Module): - "Tversky similarity: S = θ·f(A∩B) - α·f(A\\B) - β·f(B\\A). Three modes." - def __init__(self, in_features: int, out_features: int, num_features: int = 16, - group_size: int = 64, use_shared_features: bool = False, - membership: str = "sigmoid"): - super().__init__() - self.group_size = group_size - self.num_features = num_features - self.membership_type = membership - self.no_features_mode = (num_features == 0) - - if not self.no_features_mode and not use_shared_features: - self.features = nn.Parameter(torch.empty(num_features, in_features).uniform_(-0.02, 0.02)) - else: - self.register_parameter('features', None) - - self.prototypes = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.02, 0.02)) - self.theta = nn.Parameter(torch.tensor(1.0)) - self.alpha = nn.Parameter(torch.tensor(0.5)) - self.beta = nn.Parameter(torch.tensor(0.5)) - - def _ternary_ste(self, w: Tensor) -> Tensor: - w_bf16 = w.bfloat16() - g = self.group_size - w_grouped = w_bf16.reshape(-1, g) - scale = w_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_grouped / scale).round().clamp(-1, 1) - w_ternary = w_bf16 + ((q * scale).reshape(w_bf16.shape) - w_bf16).detach() - return w_ternary.reshape(w.shape) - - def _membership(self, t: Tensor) -> Tensor: - if self.membership_type == "poly": - return torch.clamp(t * 5.0 / 4.0 + 0.5, 0.0, 1.0) - elif self.membership_type == "tanh": - return (torch.tanh(t * 5.0) + 1.0) * 0.5 - else: - return torch.sigmoid(t * 5.0) - - def forward(self, x: Tensor, shared_features: Tensor | None = None) -> Tensor: - proto = self._ternary_ste(self.prototypes) - - if self.no_features_mode: - # NoFeatures: prototypes are their own feature universe - x_f = x @ proto.t() # [B, S, out] - p_norm = F.normalize(proto, dim=-1) - p_f = p_norm @ p_norm.t() # [out, out] - else: - feat = (shared_features if shared_features is not None else self.features).float() - x_f = x @ feat.t() # [B, S, nf] - p_f = proto @ feat.t() # [out, nf] - - x_s = self._membership(x_f) - p_s = self._membership(p_f) - x_a = x_f * x_s - p_a = p_f * p_s - - t, a, b = self.theta.abs(), self.alpha.abs(), self.beta.abs() - return t * (x_a @ p_a.t()) - a * (x_a @ (1 - p_s).t()) - b * ((1 - x_s) @ p_a.t()) - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: - param.data = param.data.float() - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, no_cache: bool = False, - rope_type: str = "rope", yarn_max_len: int = 4096, train_seq_len: int = 1024): - super().__init__() - self.no_cache = no_cache - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - if rope_type == "yarn": - scale = train_seq_len / yarn_max_len - freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) - ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) - inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len, device, dtype): - if self.no_cache: - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - return freqs.cos()[None, :, None, :].to(dtype=dtype), freqs.sin()[None, :, None, :].to(dtype=dtype) - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size=64, attn_proj_type="standard", tversky_num_features=16, - tversky_feature_pools=0, no_cache=False, rope_type="rope", - yarn_max_len=4096, train_seq_len=1024, tversky_membership="sigmoid", - diff_attn=False): - super().__init__() - self.num_heads, self.num_kv_heads = num_heads, num_kv_heads - self.head_dim = dim // num_heads - self.diff_attn = diff_attn - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.c_qkv = TernaryLinear(dim, self.q_size + 2 * self.kv_size, bias=False, group_size=group_size) - self.proj = NormedTernaryLinear(dim, dim, bias=False, group_size=group_size) if attn_proj_type != "tversky" else None - if self.proj is not None: - self.proj._zero_init = True - self.tversky_proj = TverskyProjection( - dim, dim, num_features=tversky_num_features, group_size=group_size, - use_shared_features=(tversky_feature_pools > 0), - membership=tversky_membership, - ) if attn_proj_type == "tversky" else None - self.shared_features = None - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - if diff_attn: - self.diff_lambda = nn.Parameter(torch.full((num_heads,), 0.5, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, no_cache=no_cache, - rope_type=rope_type, yarn_max_len=yarn_max_len, - train_seq_len=train_seq_len) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qkv_out = self.c_qkv(x) - q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - if self.diff_attn: - half = self.head_dim // 2 - q1, q2 = q[..., :half], q[..., half:] - k1, k2 = k[..., :half], k[..., half:] - v1, v2 = v[..., :half], v[..., half:] - y1 = flash_attn_func(q1.contiguous(), k1.contiguous(), v1.contiguous(), causal=True) - y2 = flash_attn_func(q2.contiguous(), k2.contiguous(), v2.contiguous(), causal=True) - lam = self.diff_lambda.to(dtype=y1.dtype)[None, None, :, None] - y = torch.cat([y1 - lam * y2, y1 + lam * y2], dim=-1) - else: - y = flash_attn_func( - q.contiguous(), - k.contiguous(), - v.contiguous(), - causal=True - ) - y = y.reshape(bsz, seqlen, dim) - return self.tversky_proj(y, self.shared_features) if self.tversky_proj is not None else self.proj(y) - -class MLP(nn.Module): - def __init__(self, dim, mlp_mult, group_size=64, activation="swiglu", mlp_groups=0): - super().__init__() - hidden = mlp_mult * dim - self.activation = activation - if mlp_groups > 0: - if activation == "swiglu": - self.gate_up = GroupedTernaryLinear(dim, hidden * 2, groups=mlp_groups, group_size=group_size) - else: - self.fc = GroupedTernaryLinear(dim, hidden, groups=mlp_groups, group_size=group_size) - self.proj = GroupedTernaryLinear(hidden, dim, groups=mlp_groups, group_size=group_size, normed=True) - else: - if activation == "swiglu": - self.gate_up = TernaryLinear(dim, hidden * 2, bias=False, group_size=group_size) - else: - self.fc = TernaryLinear(dim, hidden, bias=False, group_size=group_size) - self.proj = NormedTernaryLinear(hidden, dim, bias=False, group_size=group_size) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - if self.activation == "swiglu": - gu = self.gate_up(x) - gate, up = gu.chunk(2, dim=-1) - return self.proj(F.silu(gate) * up) - elif self.activation == "relu": - return self.proj(torch.relu(self.fc(x))) - elif self.activation == "leaky_relu": - return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.01)) - else: # relu2 - return self.proj(torch.relu(self.fc(x)).square()) - -class SmearModule(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - cumsum = x.cumsum(dim=1) - counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) - smeared = cumsum / counts - gate = torch.tanh(self.gate.to(dtype=x.dtype)) - return x + gate * (smeared - x) - - -class CausalConvRefiner(nn.Module): - "Causal Conv1d that refines hidden states using local n-gram context." - def __init__(self, dim: int, kernel_size: int = 3): - super().__init__() - self.kernel_size = kernel_size - self.conv = nn.Conv1d(dim, dim, kernel_size, padding=0, bias=False) - self.gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - h = x.permute(0, 2, 1) # [B, D, S] - h = F.pad(h, (self.kernel_size - 1, 0)) # causal pad - h = self.conv(h) - h = h.permute(0, 2, 1) # [B, S, D] - return x + torch.tanh(self.gate.to(dtype=x.dtype)) * F.rms_norm(h, (h.size(-1),)) - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - rope_base: float, qk_gain_init: float, group_size: int=64, - activation: str="swiglu", attn_proj_type: str="standard", - tversky_num_features: int=16, tversky_feature_pools: int=0, no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn: bool=False, mlp_groups: int=0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size, attn_proj_type, tversky_num_features, - tversky_feature_pools, no_cache, rope_type, yarn_max_len, - train_seq_len, tversky_membership, diff_attn) - self.mlp = MLP(dim, mlp_mult, group_size, activation, mlp_groups) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.smear = SmearModule(dim) if smear else None - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0] * x + mix[1] * x0 - n = self.attn_norm(x) - x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) - x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) - if self.smear is not None: - x = self.smear(x) - return x - -class GPT(nn.Module): - def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, - tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, - group_size: int = 64, activation: str = "swiglu", mtp_heads_count: int = 0, - embed_dim: int = 0, attn_proj_type: str = "standard", logit_head_type: str = "standard", - tversky_num_features: int = 16, tversky_feature_pools: int = 0, - training_depth_recurrence: int=1, fp_storage=False, bigram_hash: bool=False, - softcap_type: str="poly", no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn=False, mlp_groups=0, refiner=False, refiner_kernel=3): - super().__init__() - self.training_depth_recurrence = training_depth_recurrence - self.fp_storage = fp_storage - self.tie_embeddings = tie_embeddings - self.logit_softcap = logit_softcap - self.softcap_type = softcap_type - self.embed_dim = embed_dim if embed_dim > 0 else model_dim - self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) - self.bigram_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) if bigram_hash else None - if self.bigram_emb is not None: - nn.init.zeros_(self.bigram_emb.weight) - self.lm_head_correction = nn.Parameter( - torch.zeros(vocab_size, self.embed_dim)) if tie_embeddings == 2 else None - self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, fp_storage=fp_storage) if self.embed_dim != model_dim else None - self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, fp_storage=fp_storage) if ( - self.embed_dim != model_dim and logit_head_type != "tversky") else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - - # Shared Tversky feature pools (if enabled and num_features > 0) - if attn_proj_type == "tversky" and tversky_feature_pools > 0 and tversky_num_features > 0: - self.tversky_feature_pools_list = nn.ParameterList([ - nn.Parameter(torch.empty(tversky_num_features, model_dim).uniform_(-0.02, 0.02)) - for _ in range(tversky_feature_pools) - ]) - else: - self.tversky_feature_pools_list = None - - self.blocks = nn.ModuleList([ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - group_size, activation, attn_proj_type, tversky_num_features, tversky_feature_pools, - no_cache, smear, rope_type, yarn_max_len, train_seq_len, tversky_membership, - diff_attn, mlp_groups) - for _ in range(num_layers) - ]) - - # Inject shared feature pool references into attention layers - if self.tversky_feature_pools_list is not None: - for i, block in enumerate(self.blocks): - pool_idx = (i * tversky_feature_pools) // num_layers - block.attn.shared_features = self.tversky_feature_pools_list[pool_idx] - - self.final_norm = RMSNorm() - self.refiner = CausalConvRefiner(model_dim, kernel_size=refiner_kernel) if refiner else None - self.mtp_heads = nn.ModuleList([ - nn.Linear(model_dim, vocab_size, bias=False) for _ in range(mtp_heads_count) - ]) - for h in self.mtp_heads: - nn.init.zeros_(h.weight) - self.logit_head_type = logit_head_type - if logit_head_type == "tversky" and tversky_num_features == 0 and vocab_size > 1024: - raise ValueError( - f"Tversky logit head with no-features mode creates O(V^2) = {vocab_size}x{vocab_size} " - f"matrix per forward pass. Use tversky_num_features > 0 or a smaller vocab." - ) - self.tversky_head = TverskyProjection( - model_dim, vocab_size, num_features=tversky_num_features, - membership=tversky_membership, - ) if logit_head_type == "tversky" else None - self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) - self.lm_head._zero_init = True - if self.lm_head is not None and (tie_embeddings or logit_head_type == "tversky"): - self.lm_head.weight.requires_grad_(False) - - self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) - self._init_weights(tied_embed_init_std) - - def _init_weights(self, tied_embed_init_std: float) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) - for module in self.modules(): - if isinstance(module, TernaryLinear) and not getattr(module, "_zero_init", False): - nn.init.normal_(module.weight, mean=0.0, std=0.02) - elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def _compute_logits(self, x: Tensor) -> Tensor: - if self.tversky_head is not None: - logits_raw = self.tversky_head(x) - elif self.tie_embeddings: - if self.embed_proj_rev is not None: - proj = self.embed_proj_rev(x) - else: - proj = x - weight = self.tok_emb.weight - if self.lm_head_correction is not None: - weight = weight + self.lm_head_correction - logits_raw = F.linear(proj, weight.to(x.dtype)) - else: - logits_raw = self.lm_head(x) - return logits_raw + self.vocab_bias.to(x.dtype) - - def _softcap(self, logits: Tensor) -> Tensor: - s = self.logit_softcap - if self.softcap_type == "tanh": - return s * torch.tanh(logits / s) - x_sc = torch.clamp(logits / s, -2.0, 2.0) - x2 = x_sc * x_sc - return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) - - def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", temperature: float = 1.0) -> Tensor: - x = self.tok_emb(input_ids).float() - if self.bigram_emb is not None: - prev = F.pad(input_ids[:, :-1], (1, 0), value=0) - x = x + self.bigram_emb(prev).float() - if self.embed_proj is not None: - x = self.embed_proj(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - - # U-Net style encoder/decoder with skip connections - skips = [] - for i in range(self.num_encoder_layers): - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[bi](x, x0) - - x_normed = self.final_norm(x) - if self.refiner is not None: - x_normed = self.refiner(x_normed) - - # Standard training/eval path - x_flat = x_normed.reshape(-1, x_normed.size(-1)) - targets = target_ids.reshape(-1) - logits = self._softcap(self._compute_logits(x_flat)) - - if reduction == "none": - return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) - - # Fused CE + Z-loss: single logsumexp computation - logits_f = logits.float() - lse = torch.logsumexp(logits_f, dim=-1) - target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) - main_loss = (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() - - # Multi-token prediction auxiliary loss (training only) - if self.training and len(self.mtp_heads) > 0: - mtp_loss = torch.zeros((), device=main_loss.device) - for k, head in enumerate(self.mtp_heads): - shift = k + 2 - if target_ids.shape[1] > shift: - mtp_tgt = target_ids[:, shift:].reshape(-1) - mtp_in = x_normed[:, :target_ids.shape[1] - shift, :].reshape(-1, x_normed.shape[-1]) - mtp_loss = mtp_loss + F.cross_entropy(head(mtp_in).float(), mtp_tgt, reduction="mean") - main_loss = main_loss + 0.1 * mtp_loss / len(self.mtp_heads) - return main_loss - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- - -def build_luts(sp, vocab_size: int, device: torch.device): - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): - files = sorted(glob.glob(pattern)) - assert files, f"No files: {pattern}" - tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() - if max_tok > 0: tok = tok[:max_tok + 1] - u = ((tok.numel() - 1) // seq_len) * seq_len - return tok[:u + 1] - -def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature: float = 1.0): - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_start in range(seq_start, seq_end, local_batch_seqs): - batch_end = min(batch_start + local_batch_seqs, seq_end) - raw_start = batch_start * args.train_seq_len - raw_end = batch_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) - x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch_loss = model(x, y, temperature=temperature).detach() - n = float(y.numel()) - loss_sum += batch_loss.to(torch.float64) * n - token_count += n - prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) - tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) - tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride: int = 64, temperature: float = 1.0): - seq_len = args.train_seq_len - batch_size = args.sliding_batch_size - total_tokens = val_tokens.numel() - 1 - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - all_starts = list(range(0, total_tokens - seq_len, stride)) - my_starts = all_starts[rank::world_size] - - model.eval() - with torch.inference_mode(): - for i in range(0, len(my_starts), batch_size): - batch_starts = my_starts[i:i + batch_size] - - starts_t = torch.tensor(batch_starts, dtype=torch.int64) - offsets = torch.arange(seq_len + 1, dtype=torch.int64) - indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) - - local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) - x = local_batch[:, :-1] - y = local_batch[:, 1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() - - for b, start in enumerate(batch_starts): - score_from = 0 if start == 0 else seq_len - stride - scored = per_token_loss[b, score_from:] - sx, sy = x[b, score_from:], y[b, score_from:] - - loss_sum += scored.to(torch.float64).sum() - token_count += scored.numel() - - tok_bytes = base_bytes_lut[sy].to(torch.int16) - tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -# --------------------------------------------------------------------------- -# Temperature scaling -# --------------------------------------------------------------------------- -def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut): - best_t, best_loss = 1.0, float("inf") - for t in [0.90, 0.95, 1.00, 1.05, 1.10]: - loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, temperature=t) - if loss < best_loss: - best_loss = loss - best_t = t - return best_t - -# --------------------------------------------------------------------------- -# Training -# --------------------------------------------------------------------------- -def main() -> None: - args = Hyperparameters() - code = Path(__file__).read_text(encoding="utf-8") - - if args.matrix_optimizer != "adamw": - global ns_orth - ns_orth = torch.compile(ns_orth) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - grad_accum_steps = max(1, 8 // world_size) - grad_scale = 1.0 / grad_accum_steps - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - os.makedirs("logs/cuda/", exist_ok=True) - logfile = f"logs/cuda/{args.run_id}.txt" if master_process else None - if master_process: - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - - log0(f"Python {sys.version}", console=False) - log0(f"PyTorch {torch.__version__}", console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - val_tokens = ld_val(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts( - sp, args.vocab_size, device) - - # --- Model --- - base_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - group_size=args.bitnet_group_size, activation=args.activation_type, mtp_heads_count=args.mtp_heads_count, - embed_dim=args.embed_dim, attn_proj_type=args.attn_proj_type, logit_head_type=args.logit_head_type, - tversky_num_features=args.tversky_num_features, tversky_feature_pools=args.tversky_feature_pools, - training_depth_recurrence=args.training_depth_recurrence, fp_storage=args.fp_storage, - bigram_hash=args.bigram_hash, softcap_type=args.softcap_type, no_cache=(args.compile_mode == "reduce-overhead"), - smear=args.smear, rope_type=args.rope_type, yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, - tversky_membership=args.tversky_membership, diff_attn=args.diff_attn, - refiner=args.refiner, refiner_kernel=args.refiner_kernel, mlp_groups=args.mlp_groups, - ).to(device).bfloat16() - - for module in base_model.modules(): - if isinstance(module, nn.Linear): - module.float() - restore_low_dim_params_to_fp32(base_model) - if base_model.lm_head is not None and (args.tie_embeddings or args.logit_head_type == "tversky"): - base_model.lm_head.weight.requires_grad_(False) - - torch._dynamo.config.optimize_ddp = False - - compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else None) - use_find_unused = args.untie_at_fraction > 0 or args.mtp_heads_count > 0 or not args.tie_embeddings - model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, - find_unused_parameters=use_find_unused, - static_graph=not use_find_unused, - gradient_as_bucket_view=True) if distributed else compiled_model - - # --- Optimizers --- - _excl = {"tok_emb.weight", "lm_head.weight", "lm_head_correction"} - all_other_params = [(n, p) for n, p in base_model.named_parameters() - if not any(eh in n for eh in _excl)] - matrix_params = [p for n, p in all_other_params - if p.ndim == 2 and not any(pat in n for pat in CTP)] - scalar_params = [p for n, p in all_other_params - if p.ndim < 2 or any(pat in n for pat in CTP)] - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - opt_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - if args.matrix_optimizer == "adamw": - opt_muon = torch.optim.AdamW( - [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) - else: - opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, wd=args.muon_wd) - for g in opt_muon.param_groups: - g["base_lr"] = args.matrix_lr - opt_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - opt_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - - optimizers = [opt for opt in [opt_tok, opt_muon, opt_scalar, opt_head] if opt is not None] - - if base_model.lm_head_correction is not None: - opt_corr = torch.optim.Adam( - [{"params": [base_model.lm_head_correction], - "lr": args.corr_weight_lr, "base_lr": args.corr_weight_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - optimizers.append(opt_corr) - - # --- Log all hyperparameters --- - log0("--- Hyperparameters ---", console=False) - log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) if not a.startswith("_") and a not in ("train_files","val_files") and not callable(getattr(args,a))), console=False) - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"params:{n_params} L:{args.num_layers} d:{args.model_dim} h:{args.num_heads} kv:{args.num_kv_heads} ws:{world_size} ga:{grad_accum_steps} s:{args.seed}") - - # --- Data loader & helpers --- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all(): - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float): - if args.warmdown_fraction <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) - return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 - warmdown_ms = max_wallclock_ms * args.warmdown_fraction - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - _seq_switched = False - _batch_switched = False - active_seq_len = args.seq_len_start if args.seq_len_start > 0 else args.train_seq_len - active_batch_tokens = args.batch_tokens_start if args.batch_tokens_start > 0 else args.train_batch_tokens - - # --- Compiler warmup --- - if args.warmup_steps > 0: - _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} - _os = [copy.deepcopy(o.state_dict()) for o in optimizers] - model.train() - for ws in range(args.warmup_steps): - zero_grad_all() - for mi in range(grad_accum_steps): - if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) - (loss * grad_scale).backward() - for o in optimizers: o.step() - zero_grad_all() - log0(f"warmup:{ws+1}/{args.warmup_steps}") - base_model.load_state_dict(_ms, strict=True) - for o, s in zip(optimizers, _os): o.load_state_dict(s) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # --- Main training loop --- - training_time_ms = 0.0 - stop_after_step: int | None = None - _untied = False - train_loss = torch.zeros((), device=device) - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - tstats = tern_stats(base_model, group_size=args.bitnet_group_size) - log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms zero_frac:{tstats['zero_frac']:.3f}") - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - - # Sequence length schedule - if args.seq_len_start > 0 and not _seq_switched: - if max_wallclock_ms is not None: - should_switch_seq = elapsed_ms >= args.seq_schedule_fraction * max_wallclock_ms - else: - should_switch_seq = step >= int(args.iterations * args.seq_schedule_fraction) - if should_switch_seq: - active_seq_len = args.train_seq_len - _seq_switched = True - torch._dynamo.reset() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - log0(f"step:{step} seq_len_switch:{args.seq_len_start}->{active_seq_len}") - - # Batch size schedule - if args.batch_tokens_start > 0 and not _batch_switched: - if max_wallclock_ms is not None: - should_switch_batch = elapsed_ms >= args.batch_schedule_fraction * max_wallclock_ms - else: - should_switch_batch = step >= int(args.iterations * args.batch_schedule_fraction) - if should_switch_batch: - active_batch_tokens = args.train_batch_tokens - _batch_switched = True - log0(f"step:{step} batch_switch:{args.batch_tokens_start}->{active_batch_tokens}") - - zero_grad_all() - train_loss.zero_() - - for micro in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = model(x, y) - train_loss.add_(loss.detach()) - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - # Untie lm_head at configured fraction of training - if args.untie_at_fraction > 0: - if max_wallclock_ms is not None: - should_untie = not _untied and elapsed_ms >= args.untie_at_fraction * max_wallclock_ms - else: - should_untie = not _untied and step >= int(args.iterations * args.untie_at_fraction) - if should_untie and base_model.tie_embeddings: - with torch.no_grad(): - base_weight = base_model.tok_emb.weight.float() - if base_model.lm_head_correction is not None: - base_weight = base_weight + base_model.lm_head_correction.float() - if base_model.embed_proj_rev is not None: - full_weight = base_weight @ base_model.embed_proj_rev.weight.float() - else: - full_weight = base_weight - base_model.lm_head.weight.copy_(full_weight) - base_model.tie_embeddings = False - base_model.lm_head.weight.requires_grad_(True) - for g in opt_head.param_groups: - g["lr"] = g["base_lr"] = args.head_lr - _untied = True - torch._dynamo.reset() - log0(f"step:{step} untied lm_head (head_lr={args.head_lr})") - - # Muon momentum warmup - if args.matrix_optimizer != "adam": - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - for g in opt_muon.param_groups: - g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - - # LR scheduling - for opt in optimizers: - for g in opt.param_groups: - g["lr"] = g["base_lr"] * scale - opt.step() - zero_grad_all() - step += 1 - approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if args.train_log_every > 0 and step % args.train_log_every == 0: - log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} t:{approx_ms:.0f}ms avg:{approx_ms/step:.1f}ms") - if args.churn_log_every > 0 and step % args.churn_log_every == 0: - log0(f"step:{step} churn:{churn_fn(base_model, args.bitnet_group_size):.4f} zero:{tern_stats(base_model, args.bitnet_group_size)['zero_frac']:.3f}") - - # Wallclock cap sync - if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: - reached_cap = approx_ms >= max_wallclock_ms - if distributed: - cap_t = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) - reached_cap = bool(cap_t.item()) - if reached_cap: - stop_after_step = step - - # --- Serialization --- - if master_process: - sd = base_model.state_dict() - if base_model.tie_embeddings or args.logit_head_type == "tversky": - sd.pop("lm_head.weight", None) - - # Compute ternary overrides for no-features Tversky prototypes - ternary_overrides = set() - for n, m in base_model.named_modules(): - if isinstance(m, TverskyProjection) and m.no_features_mode: - ternary_overrides.add(n + ".prototypes") - ternary_overrides = ternary_overrides or None - - # Two methods: Standard Base-3 vs Bitmask Mapping - methods = {} - for method in ("standard", "bitmask"): - q_obj, stats = q_sd(sd, group_size=args.bitnet_group_size, fp_storage=args.fp_storage, ternary_method=method, ternary_override_names=ternary_overrides) - buf = io.BytesIO() - torch.save(q_obj, buf) - methods[method] = {"blob": lzma.compress(buf.getvalue(), preset=9), "stats": stats} - best = min(methods, key=lambda m: len(methods[m]["blob"])) - final_blob, q_stats = methods[best]["blob"], methods[best]["stats"] - with open("final_model.ternary.ptz", "wb") as f: - f.write(final_blob) - - artifact_bytes = len(final_blob) - code_bytes = len(code.encode("utf-8")) - - total = artifact_bytes + code_bytes - log0(f"artifact:{artifact_bytes/1e6:.2f}MB ternary:{q_stats['ternary_params']}({q_stats['ternary_bytes']}B) fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") - log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) {'FITS' if total <= 16000000 else 'OVER'}") - - if args.eval_depth_recurrence > 0: - base_model.training_depth_recurrence = args.eval_depth_recurrence - log0(f"eval_depth_recurrence:{args.eval_depth_recurrence}") - - # --- All ranks load roundtrip weights and evaluate --- - if distributed: - dist.barrier() - - with open("final_model.ternary.ptz", "rb") as f: - loaded = torch.load(io.BytesIO(lzma.decompress(f.read())), map_location="cpu", weights_only=False) - base_model.load_state_dict(deq_sd(loaded), strict=False) - torch._dynamo.reset() - - q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - log0(f"final_ternary_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") - - opt_temp = 1.0 - if args.temp_scaling: - torch.cuda.synchronize() - t_temp = time.perf_counter() - calibration_tokens = train_loader.stream.take(65536).to(device) - opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut) - torch.cuda.synchronize() - temp_time_ms = 1000.0 * (time.perf_counter() - t_temp) - log0(f"temp_scaling optimal_T:{opt_temp:.2f} eval_time:{temp_time_ms:.0f}ms") - - if args.sliding_eval: - torch.cuda.synchronize() - t_sliding = time.perf_counter() - sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, stride=args.sliding_eval_stride, - temperature=opt_temp) - torch.cuda.synchronize() - sliding_time_ms = 1000.0 * (time.perf_counter() - t_sliding) - log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " - f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) eval_time:{sliding_time_ms:.0f}ms") - - if distributed: - dist.destroy_process_group() - -if __name__ == "__main__": - main() -==================================================================================================== -Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] -PyTorch 2.10.0+cu128 ---- Hyperparameters --- -activation_type=relu2 adam_eps=1e-08 adam_lr=0.05 adam_wd=0.05 attn_proj_type=standard batch_schedule_fraction=0.33 batch_tokens_start=0 beta1=0.9 beta2=0.95 bigram_hash=False bitnet_group_size=128 churn_log_every=0 compile_mode=default corr_weight_lr=0.02 data_path=./data/datasets/fineweb10B_sp8192 diff_attn=False embed_dim=254 embed_lr=0.6 eval_depth_recurrence=0 fp_storage=True grad_clip_norm=0.0 head_lr=0.02 iterations=10000 logit_head_type=standard logit_softcap=10.0 matrix_lr=0.04 matrix_optimizer=muon max_wallclock_seconds=599.0 mlp_groups=0 mlp_mult=4 model_dim=768 mtp_heads_count=0 muon_backend_steps=3 muon_momentum=0.95 muon_momentum_warmup_start=0.85 muon_momentum_warmup_steps=500 muon_wd=0.0 num_heads=8 num_kv_heads=4 num_layers=10 qk_gain_init=2.25 refiner=False refiner_kernel=3 rope_base=5000.0 rope_type=yarn run_id=pushing_run_ternary_3 scalar_lr=0.02 seed=7 seq_len_start=0 seq_schedule_fraction=0.0 sliding_batch_size=256 sliding_eval=True sliding_eval_stride=16 smear=False softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.02 tokenizer_path=./data/tokenizers/fineweb_8192_bpe.model train_batch_tokens=524288 train_log_every=1000 train_seq_len=1024 training_depth_recurrence=0 tversky_feature_pools=1 tversky_membership=sigmoid tversky_num_features=128 untie_at_fraction=0.0 val_batch_size=524288 val_loss_every=0 vocab_size=8192 warmdown_fraction=0.2 warmup_steps=5 yarn_max_len=2048 -params:73685840 L:10 d:768 h:8 kv:4 ws:8 ga:1 s:7 -warmup:1/5 -warmup:2/5 -warmup:3/5 -warmup:4/5 -warmup:5/5 -step:1000/10000 loss:3.3114 t:91476ms avg:91.5ms -step:2000/10000 loss:3.2911 t:183385ms avg:91.7ms -step:3000/10000 loss:3.1520 t:275324ms avg:91.8ms -step:4000/10000 loss:3.3148 t:367161ms avg:91.8ms -step:5000/10000 loss:3.1792 t:459015ms avg:91.8ms -step:6000/10000 loss:3.0227 t:550864ms avg:91.8ms -step:6530/10000 val_loss:3.0535 val_bpb:1.1823 train_time:599602ms zero_frac:0.336 -stopping_early: wallclock_cap train_time:599602ms step:6530/10000 -artifact:15.92MB ternary:64880640(14237274B) fp:2513744(2537376B) code:70853 -budget:15992753/16000000 (15.99/16.00MB) FITS -final_ternary_roundtrip val_loss:3.0604 val_bpb:1.1850 -temp_scaling optimal_T:0.90 eval_time:150ms -final_sliding val_loss:2.9901 val_bpb:1.1578 (stride=16, T=0.90) eval_time:429023ms diff --git a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/train_gpt_cuda_ternary.py b/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/train_gpt_cuda_ternary.py deleted file mode 100644 index 750bdff79d..0000000000 --- a/records/track_10min_16mb/2026-03-24_74M_Ternary_UNet_FP8_10L_8192BPE_YaRN_NeoMuon/train_gpt_cuda_ternary.py +++ /dev/null @@ -1,1436 +0,0 @@ -"Ternary training script for OpenAI's Parameter Golf Challenge. Ciprian-Florin Ifrim - 24 March 2026" - -import copy -import glob -import io -import math -import os -import random -import sys -import time -import lzma -from pathlib import Path -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func - -# --------------------------------------------------------------------------- -# Hyperparameters (all configurable via environment variables) -# --------------------------------------------------------------------------- -def _e(k, d, t=str): - v = os.environ.get(k, str(d)) - if t == bool: return bool(int(v)) - return t(v) - -class Hyperparameters: - data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", f"run_{int(time.time())}") - seed = _e("SEED", 1337, int) - compile_mode = _e("COMPILE_MODE", "default") - val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) - val_loss_every = _e("VAL_LOSS_EVERY", 500, int) - train_log_every = _e("TRAIN_LOG_EVERY", 10, int) - iterations = _e("ITERATIONS", 2000, int) - warmdown_fraction = _e("WARMDOWN_FRACTION", 0.2, float) - warmup_steps = _e("WARMUP_STEPS", 20, int) - train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) - train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) - max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) - vocab_size = _e("VOCAB_SIZE", 1024, int) - num_layers = _e("NUM_LAYERS", 16, int) - num_kv_heads = _e("NUM_KV_HEADS", 4, int) - model_dim = _e("MODEL_DIM", 512, int) - num_heads = _e("NUM_HEADS", 8, int) - mlp_mult = _e("MLP_MULT", 2, int) - tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) - rope_base = _e("ROPE_BASE", 10000.0, float) - rope_type = _e("ROPE_TYPE", "rope") - yarn_max_len = _e("YARN_MAX_LEN", 4096, int) - logit_softcap = _e("LOGIT_SOFTCAP", 30.0, float) - softcap_type = _e("SOFTCAP_TYPE", "poly") - tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) - qk_gain_init = _e("QK_GAIN_INIT", 1.5, float) - activation_type = _e("ACTIVATION", "swiglu") - embed_dim = _e("EMBED_DIM", 0, int) - bigram_hash = _e("BIGRAM_HASH", 0, bool) - mtp_heads_count = _e("MTP_HEADS", 0, int) - training_depth_recurrence = _e("TRAINING_DEPTH_RECURRENCE", 1, int) - eval_depth_recurrence = _e("EVAL_DEPTH_RECURRENCE", 1, int) - attn_proj_type = _e("ATTN_PROJ_TYPE", "standard") - logit_head_type = _e("LOGIT_HEAD_TYPE", "standard") - tversky_num_features = _e("TVERSKY_NUM_FEATURES", 16, int) - tversky_feature_pools = _e("TVERSKY_FEATURE_POOLS", 0, int) - tversky_membership = _e("TVERSKY_MEMBERSHIP", "sigmoid") - diff_attn = _e("DIFF_ATTN", 0, bool) - refiner = _e("REFINER", 0, bool) - refiner_kernel = _e("REFINER_KERNEL", 3, int) - mlp_groups = _e("MLP_GROUPS", 0, int) - embed_lr = _e("EMBED_LR", 0.6, float) - head_lr = _e("HEAD_LR", 0.008, float) - adam_lr = _e("ADAM_LR", 1e-3, float) - adam_wd = _e("ADAM_WD", 0.05, float) - untie_at_fraction = _e("UNTIE_AT_FRACTION", 0.0, float) - tied_embed_lr = _e("TIED_EMBED_LR", 0.05, float) - corr_weight_lr = _e("CORR_WEIGHT_LR", 0.05, float) - smear = _e("SMEAR", 0, bool) - seq_len_start = _e("SEQ_LEN_START", 0, int) - seq_schedule_fraction = _e("SEQ_SCHEDULE_FRACTION", 0.33, float) - batch_tokens_start = _e("BATCH_TOKENS_START", 0, int) - batch_schedule_fraction = _e("BATCH_SCHEDULE_FRACTION", 0.33, float) - churn_log_every = _e("CHURN_LOG_EVERY", 500, int) - matrix_lr = _e("MATRIX_LR", 0.04, float) - scalar_lr = _e("SCALAR_LR", 0.04, float) - muon_momentum = _e("MUON_MOMENTUM", 0.95, float) - muon_backend_steps = _e("MUON_BACKEND_STEPS", 5, int) - muon_wd = _e("MUON_WD", 0.0, float) - matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") - muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) - muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) - beta1 = _e("BETA1", 0.9, float) - beta2 = _e("BETA2", 0.95, float) - adam_eps = _e("ADAM_EPS", 1e-8, float) - grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) - bitnet_group_size = _e("BITNET_GROUP_SIZE", 64, int) - sliding_eval = _e("SLIDING_EVAL", 0, bool) - sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 64, int) - sliding_batch_size = _e("SLIDING_BATCH_SIZE", 64, int) - temp_scaling = _e("TEMP_SCALING", 0, bool) - _fp_raw = os.environ.get("FP_STORAGE", "0") - fp_storage = True if _fp_raw == "FP8" else ("fp4" if _fp_raw == "FP4" else False) - -CTP = ("attn_scale","attn_scales","mlp_scale","mlp_scales","resid_mix","resid_mixes","q_gain","diff_lambda","skip_weight","skip_weights","vocab_bias","refiner.gate") - -# --------------------------------------------------------------------------- -# Ternary packing — base-3 encoding (5 trits/byte) -# --------------------------------------------------------------------------- -def pack_ternary(q: Tensor): - f = (q.reshape(-1).to(torch.int8) + 1).numpy() - n = len(f) - p = (5 - n % 5) % 5 - if p: f = np.concatenate([f, np.zeros(p, dtype=np.int8)]) - g = f.reshape(-1, 5).astype(np.uint8) - return (g[:,0] + g[:,1]*3 + g[:,2]*9 + g[:,3]*27 + g[:,4]*81).tobytes(), n - -def unpack_ternary(data: bytes, n: int) -> Tensor: - v = np.frombuffer(data, dtype=np.uint8).astype(np.int16) - t = np.zeros((len(v), 5), dtype=np.int8) - for i in range(5): t[:,i] = v % 3; v //= 3 - return torch.from_numpy(t.reshape(-1)[:n].astype(np.int8) - 1) - -def pack_ternary_bitmask(q: Tensor): - f = q.reshape(-1).to(torch.int8).numpy(); n = len(f) - nz = (f != 0) - return np.packbits(nz).tobytes() + np.packbits(f[nz] > 0).tobytes(), n - -def unpack_ternary_bitmask(data: bytes, n: int) -> Tensor: - ms = (n + 7) // 8 - nz = np.unpackbits(np.frombuffer(data[:ms], dtype=np.uint8))[:n].astype(bool) - s = np.unpackbits(np.frombuffer(data[ms:], dtype=np.uint8))[:int(nz.sum())].astype(bool) - w = np.zeros(n, dtype=np.int8); w[nz] = np.where(s, 1, -1) - return torch.from_numpy(w) - -# --------------------------------------------------------------------------- -# FP4 quantization (per-row absmax, 2 values packed per byte) -# --------------------------------------------------------------------------- -def quantize_to_int4(t: Tensor) -> tuple[Tensor, Tensor, list]: - t32 = t.float() - orig_shape = t32.shape - if t32.ndim < 2: - t32 = t32.unsqueeze(0) - absmax = t32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(t32 / scale), -7, 7).to(torch.int8) - flat = q.reshape(-1) - if flat.numel() % 2 != 0: - flat = F.pad(flat, (0, 1)) - low = (flat[0::2] + 8).to(torch.uint8) - high = (flat[1::2] + 8).to(torch.uint8) - return low | (high << 4), scale.half().squeeze(-1), list(orig_shape) - -def dequantize_from_int4(packed: Tensor, scale: Tensor, shape: list) -> Tensor: - low = (packed & 0x0F).to(torch.int8) - 8 - high = ((packed >> 4) & 0x0F).to(torch.int8) - 8 - flat = torch.zeros(packed.numel() * 2, dtype=torch.int8) - flat[0::2] = low - flat[1::2] = high - numel = 1 - for s in shape: - numel *= s - flat = flat[:numel].float() - if len(shape) <= 1: - return (flat * scale.float().squeeze()).reshape(shape) - return (flat.reshape(-1, shape[-1]) * scale.float().unsqueeze(-1)).reshape(shape) - -# --------------------------------------------------------------------------- -# State dict serialization (ternary + fp16/fp8/fp4) -# --------------------------------------------------------------------------- -def q_sd(state_dict: dict, group_size: int = 64, fp_storage=False, ternary_method="standard", ternary_override_names: set | None = None) -> tuple[dict, dict]: - "Ternary for large 2D weight matrices, fp16/fp8/fp4 for everything else." - quantized = {} - stats = {"ternary_params": 0, "ternary_bytes": 0, "fp_params": 0, "fp_bytes": 0} - for name, tensor in state_dict.items(): - if "mtp_heads" in name: - continue - t = tensor.detach().cpu().float().contiguous() - t_orig_shape = list(t.shape) - if t.ndim == 3: - t = t.reshape(t.shape[0], -1) - is_ternary_candidate = ( - t.ndim == 2 and t.numel() > 65_536 - and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name and "bigram_emb" not in name and "lm_head_correction" not in name and "lm_head_U" not in name and "lm_head_V" not in name - and "prototypes" not in name and "tversky" not in name - ) or (ternary_override_names is not None and name in ternary_override_names) - if is_ternary_candidate: - pad = (group_size - t.shape[1] % group_size) % group_size - t_padded = F.pad(t, (0, pad)) if pad > 0 else t - t_grouped = t_padded.reshape(-1, group_size) - scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (t_grouped / scale).round().clamp(-1, 1).to(torch.int8) - - if ternary_method == "standard": - packed_bytes, n_trits = pack_ternary(q) - entry_type = "ternary" - else: - packed_bytes, n_trits = pack_ternary_bitmask(q) - entry_type = "ternary_bitmask" - - quantized[name] = { - "type": entry_type, "packed": packed_bytes, - "scale": scale.half().squeeze(-1), - "shape": list(t.shape), "padded_cols": t_padded.shape[1], - "group_size": group_size, "n_trits": n_trits, - "orig_shape": t_orig_shape, - } - stats["ternary_params"] += t.numel() - stats["ternary_bytes"] += len(packed_bytes) + scale.numel() * 2 - elif fp_storage == "fp4" and t.ndim == 2: - packed, scale, orig_shape = quantize_to_int4(t) - quantized[name] = {"type": "fp4", "packed": packed, "scale": scale, "shape": orig_shape} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += packed.numel() + scale.numel() * 2 - elif fp_storage and t.ndim == 2: - quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() - else: - quantized[name] = {"type": "fp16", "data": t.half()} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() * 2 - return quantized, stats - -def deq_sd(quantized: dict, target_dtype=torch.bfloat16): - "Reconstruct full-precision state dict from quantized representation." - out = {} - for name, entry in quantized.items(): - if entry["type"] in ("ternary", "ternary_bitmask"): - if entry["type"] == "ternary": - q = unpack_ternary(entry["packed"], entry["n_trits"]) - else: - q = unpack_ternary_bitmask(entry["packed"], entry["n_trits"]) - - q = q.float().reshape(-1, entry["group_size"]) - scale = entry["scale"].float().unsqueeze(-1) - q_absmean = q.abs().mean(-1, keepdim=True).clamp(min=1e-8) - t = (q * (scale / q_absmean)).reshape(-1, entry["padded_cols"]) - shape = entry["shape"] - result = t[:shape[0], :shape[1]].to(target_dtype) - orig = entry.get("orig_shape") - out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() - elif entry["type"] == "fp8": - out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() - elif entry["type"] == "fp4": - out[name] = dequantize_from_int4(entry["packed"], entry["scale"], entry["shape"]).to(target_dtype).contiguous() - else: - out[name] = entry["data"].to(target_dtype).contiguous() - return out - -# --------------------------------------------------------------------------- -# Ternary diagnostics (logged during training) -# --------------------------------------------------------------------------- -def tern_stats(model: nn.Module, group_size: int = 64): - total = zeros = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (w / scale).round().clamp(-1, 1) - zeros += int((q == 0).sum().item()) - total += int(q.numel()) - return {"zero_frac": zeros / max(total, 1), "total_weights": total} - -_prev_committed: dict = {} - -def churn_fn(model: nn.Module, group_size: int = 64): - global _prev_committed - total = flipped = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = (w / scale).round().clamp(-1, 1).cpu().numpy() - if name in _prev_committed: - flipped += int(np.sum(q != _prev_committed[name])) - total += q.size - _prev_committed[name] = q - return flipped / max(total, 1) - -# --------------------------------------------------------------------------- -# Muon optimizer (Newton-Schulz orthogonalized momentum) -# --------------------------------------------------------------------------- -def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, wd: float = 0.0): - super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr, momentum = group["lr"], group["momentum"] - backend_steps, nesterov = group["backend_steps"], group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() - g = ns_orth(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr:curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - wd = group.get("wd", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.mul_(1 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - -# --------------------------------------------------------------------------- -# Data loading -# --------------------------------------------------------------------------- -def ld_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" Tensor: - chunks = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos:self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank, self.world_size, self.device = rank, world_size, device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x, y - -# --------------------------------------------------------------------------- -# Model -# --------------------------------------------------------------------------- -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - -def apply_qat_ste(w: Tensor, fp_storage: str | bool) -> Tensor: - """Applies Straight-Through Estimator (STE) for FP4 or FP8 simulated quantization.""" - if not fp_storage: - return w - if fp_storage == "fp4": - absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(w / scale), -7.0, 7.0) - w_sim = q * scale - return (w_sim - w).detach() + w - elif fp_storage is True or fp_storage == "fp8": - w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) - return (w_sim - w).detach() + w - return w - -class QATLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = False, fp_storage: str | bool = False): - super().__init__(in_features, out_features, bias=bias) - self.fp_storage = fp_storage - - def forward(self, x: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.linear(x, w_qat.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) - -class QATEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, fp_storage: str | bool = False): - super().__init__(num_embeddings, embedding_dim) - self.fp_storage = fp_storage - - def forward(self, input: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.embedding(input, w_qat, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) - -class TernaryLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=False, group_size=64): - super().__init__(in_features, out_features, bias=bias) - self.group_size = group_size - - def forward(self, x: Tensor) -> Tensor: - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_g / scale).round().clamp(-1, 1) - w_ternary = w + ((q * scale).reshape(w.shape) - w).detach() - return F.linear(x, w_ternary, - self.bias.to(x.dtype) if self.bias is not None else None) - - -class NormedTernaryLinear(TernaryLinear): - "Ternary linear with RMSNorm on input — for output projections receiving un-normalized activations." - def forward(self, x: Tensor) -> Tensor: - return super().forward(F.rms_norm(x, (x.size(-1),))) - -class GroupedTernaryLinear(nn.Module): - "Grouped linear with ternary STE. Weight stored as 2D [groups*group_out, group_in] for ternary quantization compatibility." - def __init__(self, in_features, out_features, groups=4, group_size=64, normed=False): - super().__init__() - assert in_features % groups == 0 and out_features % groups == 0 - self.groups = groups - self.group_in = in_features // groups - self.group_out = out_features // groups - self.group_size = group_size - self.normed = normed - self.weight = nn.Parameter(torch.randn(groups * self.group_out, self.group_in) * 0.02) - - def forward(self, x: Tensor) -> Tensor: - if self.normed: - x = F.rms_norm(x, (x.size(-1),)) - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_g / scale).round().clamp(-1, 1) - w_ternary = w + ((q * scale).reshape(w.shape) - w).detach() - w_grouped = w_ternary.reshape(self.groups, self.group_out, self.group_in) - bsz = x.shape[:-1] - x_g = x.reshape(*bsz, self.groups, self.group_in) - out = torch.einsum('...gi,goi->...go', x_g, w_grouped) - return out.reshape(*bsz, self.groups * self.group_out) - -class TverskyProjection(nn.Module): - "Tversky similarity: S = θ·f(A∩B) - α·f(A\\B) - β·f(B\\A). Three modes." - def __init__(self, in_features: int, out_features: int, num_features: int = 16, - group_size: int = 64, use_shared_features: bool = False, - membership: str = "sigmoid"): - super().__init__() - self.group_size = group_size - self.num_features = num_features - self.membership_type = membership - self.no_features_mode = (num_features == 0) - - if not self.no_features_mode and not use_shared_features: - self.features = nn.Parameter(torch.empty(num_features, in_features).uniform_(-0.02, 0.02)) - else: - self.register_parameter('features', None) - - self.prototypes = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.02, 0.02)) - self.theta = nn.Parameter(torch.tensor(1.0)) - self.alpha = nn.Parameter(torch.tensor(0.5)) - self.beta = nn.Parameter(torch.tensor(0.5)) - - def _ternary_ste(self, w: Tensor) -> Tensor: - w_bf16 = w.bfloat16() - g = self.group_size - w_grouped = w_bf16.reshape(-1, g) - scale = w_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = (w_grouped / scale).round().clamp(-1, 1) - w_ternary = w_bf16 + ((q * scale).reshape(w_bf16.shape) - w_bf16).detach() - return w_ternary.reshape(w.shape) - - def _membership(self, t: Tensor) -> Tensor: - if self.membership_type == "poly": - return torch.clamp(t * 5.0 / 4.0 + 0.5, 0.0, 1.0) - elif self.membership_type == "tanh": - return (torch.tanh(t * 5.0) + 1.0) * 0.5 - else: - return torch.sigmoid(t * 5.0) - - def forward(self, x: Tensor, shared_features: Tensor | None = None) -> Tensor: - proto = self._ternary_ste(self.prototypes) - - if self.no_features_mode: - # NoFeatures: prototypes are their own feature universe - x_f = x @ proto.t() # [B, S, out] - p_norm = F.normalize(proto, dim=-1) - p_f = p_norm @ p_norm.t() # [out, out] - else: - feat = (shared_features if shared_features is not None else self.features).float() - x_f = x @ feat.t() # [B, S, nf] - p_f = proto @ feat.t() # [out, nf] - - x_s = self._membership(x_f) - p_s = self._membership(p_f) - x_a = x_f * x_s - p_a = p_f * p_s - - t, a, b = self.theta.abs(), self.alpha.abs(), self.beta.abs() - return t * (x_a @ p_a.t()) - a * (x_a @ (1 - p_s).t()) - b * ((1 - x_s) @ p_a.t()) - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: - param.data = param.data.float() - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, no_cache: bool = False, - rope_type: str = "rope", yarn_max_len: int = 4096, train_seq_len: int = 1024): - super().__init__() - self.no_cache = no_cache - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - if rope_type == "yarn": - scale = train_seq_len / yarn_max_len - freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) - ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) - inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len, device, dtype): - if self.no_cache: - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - return freqs.cos()[None, :, None, :].to(dtype=dtype), freqs.sin()[None, :, None, :].to(dtype=dtype) - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size=64, attn_proj_type="standard", tversky_num_features=16, - tversky_feature_pools=0, no_cache=False, rope_type="rope", - yarn_max_len=4096, train_seq_len=1024, tversky_membership="sigmoid", - diff_attn=False): - super().__init__() - self.num_heads, self.num_kv_heads = num_heads, num_kv_heads - self.head_dim = dim // num_heads - self.diff_attn = diff_attn - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.c_qkv = TernaryLinear(dim, self.q_size + 2 * self.kv_size, bias=False, group_size=group_size) - self.proj = NormedTernaryLinear(dim, dim, bias=False, group_size=group_size) if attn_proj_type != "tversky" else None - if self.proj is not None: - self.proj._zero_init = True - self.tversky_proj = TverskyProjection( - dim, dim, num_features=tversky_num_features, group_size=group_size, - use_shared_features=(tversky_feature_pools > 0), - membership=tversky_membership, - ) if attn_proj_type == "tversky" else None - self.shared_features = None - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - if diff_attn: - self.diff_lambda = nn.Parameter(torch.full((num_heads,), 0.5, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, no_cache=no_cache, - rope_type=rope_type, yarn_max_len=yarn_max_len, - train_seq_len=train_seq_len) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qkv_out = self.c_qkv(x) - q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - if self.diff_attn: - half = self.head_dim // 2 - q1, q2 = q[..., :half], q[..., half:] - k1, k2 = k[..., :half], k[..., half:] - v1, v2 = v[..., :half], v[..., half:] - y1 = flash_attn_func(q1.contiguous(), k1.contiguous(), v1.contiguous(), causal=True) - y2 = flash_attn_func(q2.contiguous(), k2.contiguous(), v2.contiguous(), causal=True) - lam = self.diff_lambda.to(dtype=y1.dtype)[None, None, :, None] - y = torch.cat([y1 - lam * y2, y1 + lam * y2], dim=-1) - else: - y = flash_attn_func( - q.contiguous(), - k.contiguous(), - v.contiguous(), - causal=True - ) - y = y.reshape(bsz, seqlen, dim) - return self.tversky_proj(y, self.shared_features) if self.tversky_proj is not None else self.proj(y) - -class MLP(nn.Module): - def __init__(self, dim, mlp_mult, group_size=64, activation="swiglu", mlp_groups=0): - super().__init__() - hidden = mlp_mult * dim - self.activation = activation - if mlp_groups > 0: - if activation == "swiglu": - self.gate_up = GroupedTernaryLinear(dim, hidden * 2, groups=mlp_groups, group_size=group_size) - else: - self.fc = GroupedTernaryLinear(dim, hidden, groups=mlp_groups, group_size=group_size) - self.proj = GroupedTernaryLinear(hidden, dim, groups=mlp_groups, group_size=group_size, normed=True) - else: - if activation == "swiglu": - self.gate_up = TernaryLinear(dim, hidden * 2, bias=False, group_size=group_size) - else: - self.fc = TernaryLinear(dim, hidden, bias=False, group_size=group_size) - self.proj = NormedTernaryLinear(hidden, dim, bias=False, group_size=group_size) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - if self.activation == "swiglu": - gu = self.gate_up(x) - gate, up = gu.chunk(2, dim=-1) - return self.proj(F.silu(gate) * up) - elif self.activation == "relu": - return self.proj(torch.relu(self.fc(x))) - elif self.activation == "leaky_relu": - return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.01)) - else: # relu2 - return self.proj(torch.relu(self.fc(x)).square()) - -class SmearModule(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - cumsum = x.cumsum(dim=1) - counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) - smeared = cumsum / counts - gate = torch.tanh(self.gate.to(dtype=x.dtype)) - return x + gate * (smeared - x) - - -class CausalConvRefiner(nn.Module): - "Causal Conv1d that refines hidden states using local n-gram context." - def __init__(self, dim: int, kernel_size: int = 3): - super().__init__() - self.kernel_size = kernel_size - self.conv = nn.Conv1d(dim, dim, kernel_size, padding=0, bias=False) - self.gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - - def forward(self, x: Tensor) -> Tensor: - h = x.permute(0, 2, 1) # [B, D, S] - h = F.pad(h, (self.kernel_size - 1, 0)) # causal pad - h = self.conv(h) - h = h.permute(0, 2, 1) # [B, S, D] - return x + torch.tanh(self.gate.to(dtype=x.dtype)) * F.rms_norm(h, (h.size(-1),)) - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - rope_base: float, qk_gain_init: float, group_size: int=64, - activation: str="swiglu", attn_proj_type: str="standard", - tversky_num_features: int=16, tversky_feature_pools: int=0, no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn: bool=False, mlp_groups: int=0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size, attn_proj_type, tversky_num_features, - tversky_feature_pools, no_cache, rope_type, yarn_max_len, - train_seq_len, tversky_membership, diff_attn) - self.mlp = MLP(dim, mlp_mult, group_size, activation, mlp_groups) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.smear = SmearModule(dim) if smear else None - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0] * x + mix[1] * x0 - n = self.attn_norm(x) - x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) - x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) - if self.smear is not None: - x = self.smear(x) - return x - -class GPT(nn.Module): - def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, - tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, - group_size: int = 64, activation: str = "swiglu", mtp_heads_count: int = 0, - embed_dim: int = 0, attn_proj_type: str = "standard", logit_head_type: str = "standard", - tversky_num_features: int = 16, tversky_feature_pools: int = 0, - training_depth_recurrence: int=1, fp_storage=False, bigram_hash: bool=False, - softcap_type: str="poly", no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn=False, mlp_groups=0, refiner=False, refiner_kernel=3): - super().__init__() - self.training_depth_recurrence = training_depth_recurrence - self.fp_storage = fp_storage - self.tie_embeddings = tie_embeddings - self.logit_softcap = logit_softcap - self.softcap_type = softcap_type - self.embed_dim = embed_dim if embed_dim > 0 else model_dim - self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) - self.bigram_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) if bigram_hash else None - if self.bigram_emb is not None: - nn.init.zeros_(self.bigram_emb.weight) - self.lm_head_correction = nn.Parameter( - torch.zeros(vocab_size, self.embed_dim)) if tie_embeddings == 2 else None - self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, fp_storage=fp_storage) if self.embed_dim != model_dim else None - self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, fp_storage=fp_storage) if ( - self.embed_dim != model_dim and logit_head_type != "tversky") else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - - # Shared Tversky feature pools (if enabled and num_features > 0) - if attn_proj_type == "tversky" and tversky_feature_pools > 0 and tversky_num_features > 0: - self.tversky_feature_pools_list = nn.ParameterList([ - nn.Parameter(torch.empty(tversky_num_features, model_dim).uniform_(-0.02, 0.02)) - for _ in range(tversky_feature_pools) - ]) - else: - self.tversky_feature_pools_list = None - - self.blocks = nn.ModuleList([ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - group_size, activation, attn_proj_type, tversky_num_features, tversky_feature_pools, - no_cache, smear, rope_type, yarn_max_len, train_seq_len, tversky_membership, - diff_attn, mlp_groups) - for _ in range(num_layers) - ]) - - # Inject shared feature pool references into attention layers - if self.tversky_feature_pools_list is not None: - for i, block in enumerate(self.blocks): - pool_idx = (i * tversky_feature_pools) // num_layers - block.attn.shared_features = self.tversky_feature_pools_list[pool_idx] - - self.final_norm = RMSNorm() - self.refiner = CausalConvRefiner(model_dim, kernel_size=refiner_kernel) if refiner else None - self.mtp_heads = nn.ModuleList([ - nn.Linear(model_dim, vocab_size, bias=False) for _ in range(mtp_heads_count) - ]) - for h in self.mtp_heads: - nn.init.zeros_(h.weight) - self.logit_head_type = logit_head_type - if logit_head_type == "tversky" and tversky_num_features == 0 and vocab_size > 1024: - raise ValueError( - f"Tversky logit head with no-features mode creates O(V^2) = {vocab_size}x{vocab_size} " - f"matrix per forward pass. Use tversky_num_features > 0 or a smaller vocab." - ) - self.tversky_head = TverskyProjection( - model_dim, vocab_size, num_features=tversky_num_features, - membership=tversky_membership, - ) if logit_head_type == "tversky" else None - self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) - self.lm_head._zero_init = True - if self.lm_head is not None and (tie_embeddings or logit_head_type == "tversky"): - self.lm_head.weight.requires_grad_(False) - - self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) - self._init_weights(tied_embed_init_std) - - def _init_weights(self, tied_embed_init_std: float) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) - for module in self.modules(): - if isinstance(module, TernaryLinear) and not getattr(module, "_zero_init", False): - nn.init.normal_(module.weight, mean=0.0, std=0.02) - elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def _compute_logits(self, x: Tensor) -> Tensor: - if self.tversky_head is not None: - logits_raw = self.tversky_head(x) - elif self.tie_embeddings: - if self.embed_proj_rev is not None: - proj = self.embed_proj_rev(x) - else: - proj = x - weight = self.tok_emb.weight - if self.lm_head_correction is not None: - weight = weight + self.lm_head_correction - logits_raw = F.linear(proj, weight.to(x.dtype)) - else: - logits_raw = self.lm_head(x) - return logits_raw + self.vocab_bias.to(x.dtype) - - def _softcap(self, logits: Tensor) -> Tensor: - s = self.logit_softcap - if self.softcap_type == "tanh": - return s * torch.tanh(logits / s) - x_sc = torch.clamp(logits / s, -2.0, 2.0) - x2 = x_sc * x_sc - return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) - - def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", temperature: float = 1.0) -> Tensor: - x = self.tok_emb(input_ids).float() - if self.bigram_emb is not None: - prev = F.pad(input_ids[:, :-1], (1, 0), value=0) - x = x + self.bigram_emb(prev).float() - if self.embed_proj is not None: - x = self.embed_proj(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - - # U-Net style encoder/decoder with skip connections - skips = [] - for i in range(self.num_encoder_layers): - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[bi](x, x0) - - x_normed = self.final_norm(x) - if self.refiner is not None: - x_normed = self.refiner(x_normed) - - # Standard training/eval path - x_flat = x_normed.reshape(-1, x_normed.size(-1)) - targets = target_ids.reshape(-1) - logits = self._softcap(self._compute_logits(x_flat)) - - if reduction == "none": - return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) - - # Fused CE + Z-loss: single logsumexp computation - logits_f = logits.float() - lse = torch.logsumexp(logits_f, dim=-1) - target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) - main_loss = (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() - - # Multi-token prediction auxiliary loss (training only) - if self.training and len(self.mtp_heads) > 0: - mtp_loss = torch.zeros((), device=main_loss.device) - for k, head in enumerate(self.mtp_heads): - shift = k + 2 - if target_ids.shape[1] > shift: - mtp_tgt = target_ids[:, shift:].reshape(-1) - mtp_in = x_normed[:, :target_ids.shape[1] - shift, :].reshape(-1, x_normed.shape[-1]) - mtp_loss = mtp_loss + F.cross_entropy(head(mtp_in).float(), mtp_tgt, reduction="mean") - main_loss = main_loss + 0.1 * mtp_loss / len(self.mtp_heads) - return main_loss - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- - -def build_luts(sp, vocab_size: int, device: torch.device): - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): - files = sorted(glob.glob(pattern)) - assert files, f"No files: {pattern}" - tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() - if max_tok > 0: tok = tok[:max_tok + 1] - u = ((tok.numel() - 1) // seq_len) * seq_len - return tok[:u + 1] - -def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature: float = 1.0): - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_start in range(seq_start, seq_end, local_batch_seqs): - batch_end = min(batch_start + local_batch_seqs, seq_end) - raw_start = batch_start * args.train_seq_len - raw_end = batch_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) - x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch_loss = model(x, y, temperature=temperature).detach() - n = float(y.numel()) - loss_sum += batch_loss.to(torch.float64) * n - token_count += n - prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) - tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) - tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride: int = 64, temperature: float = 1.0): - seq_len = args.train_seq_len - batch_size = args.sliding_batch_size - total_tokens = val_tokens.numel() - 1 - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - all_starts = list(range(0, total_tokens - seq_len, stride)) - my_starts = all_starts[rank::world_size] - - model.eval() - with torch.inference_mode(): - for i in range(0, len(my_starts), batch_size): - batch_starts = my_starts[i:i + batch_size] - - starts_t = torch.tensor(batch_starts, dtype=torch.int64) - offsets = torch.arange(seq_len + 1, dtype=torch.int64) - indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) - - local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) - x = local_batch[:, :-1] - y = local_batch[:, 1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() - - for b, start in enumerate(batch_starts): - score_from = 0 if start == 0 else seq_len - stride - scored = per_token_loss[b, score_from:] - sx, sy = x[b, score_from:], y[b, score_from:] - - loss_sum += scored.to(torch.float64).sum() - token_count += scored.numel() - - tok_bytes = base_bytes_lut[sy].to(torch.int16) - tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -# --------------------------------------------------------------------------- -# Temperature scaling -# --------------------------------------------------------------------------- -def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut): - best_t, best_loss = 1.0, float("inf") - for t in [0.90, 0.95, 1.00, 1.05, 1.10]: - loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, temperature=t) - if loss < best_loss: - best_loss = loss - best_t = t - return best_t - -# --------------------------------------------------------------------------- -# Training -# --------------------------------------------------------------------------- -def main() -> None: - args = Hyperparameters() - code = Path(__file__).read_text(encoding="utf-8") - - if args.matrix_optimizer != "adamw": - global ns_orth - ns_orth = torch.compile(ns_orth) - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - grad_accum_steps = max(1, 8 // world_size) - grad_scale = 1.0 / grad_accum_steps - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - os.makedirs("logs/cuda/", exist_ok=True) - logfile = f"logs/cuda/{args.run_id}.txt" if master_process else None - if master_process: - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - - log0(f"Python {sys.version}", console=False) - log0(f"PyTorch {torch.__version__}", console=False) - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - val_tokens = ld_val(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts( - sp, args.vocab_size, device) - - # --- Model --- - base_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - group_size=args.bitnet_group_size, activation=args.activation_type, mtp_heads_count=args.mtp_heads_count, - embed_dim=args.embed_dim, attn_proj_type=args.attn_proj_type, logit_head_type=args.logit_head_type, - tversky_num_features=args.tversky_num_features, tversky_feature_pools=args.tversky_feature_pools, - training_depth_recurrence=args.training_depth_recurrence, fp_storage=args.fp_storage, - bigram_hash=args.bigram_hash, softcap_type=args.softcap_type, no_cache=(args.compile_mode == "reduce-overhead"), - smear=args.smear, rope_type=args.rope_type, yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, - tversky_membership=args.tversky_membership, diff_attn=args.diff_attn, - refiner=args.refiner, refiner_kernel=args.refiner_kernel, mlp_groups=args.mlp_groups, - ).to(device).bfloat16() - - for module in base_model.modules(): - if isinstance(module, nn.Linear): - module.float() - restore_low_dim_params_to_fp32(base_model) - if base_model.lm_head is not None and (args.tie_embeddings or args.logit_head_type == "tversky"): - base_model.lm_head.weight.requires_grad_(False) - - torch._dynamo.config.optimize_ddp = False - - compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else None) - use_find_unused = args.untie_at_fraction > 0 or args.mtp_heads_count > 0 or not args.tie_embeddings - model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, - find_unused_parameters=use_find_unused, - static_graph=not use_find_unused, - gradient_as_bucket_view=True) if distributed else compiled_model - - # --- Optimizers --- - _excl = {"tok_emb.weight", "lm_head.weight", "lm_head_correction"} - all_other_params = [(n, p) for n, p in base_model.named_parameters() - if not any(eh in n for eh in _excl)] - matrix_params = [p for n, p in all_other_params - if p.ndim == 2 and not any(pat in n for pat in CTP)] - scalar_params = [p for n, p in all_other_params - if p.ndim < 2 or any(pat in n for pat in CTP)] - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - opt_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - if args.matrix_optimizer == "adamw": - opt_muon = torch.optim.AdamW( - [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) - else: - opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, wd=args.muon_wd) - for g in opt_muon.param_groups: - g["base_lr"] = args.matrix_lr - opt_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - opt_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - - optimizers = [opt for opt in [opt_tok, opt_muon, opt_scalar, opt_head] if opt is not None] - - if base_model.lm_head_correction is not None: - opt_corr = torch.optim.Adam( - [{"params": [base_model.lm_head_correction], - "lr": args.corr_weight_lr, "base_lr": args.corr_weight_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - optimizers.append(opt_corr) - - # --- Log all hyperparameters --- - log0("--- Hyperparameters ---", console=False) - log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) if not a.startswith("_") and a not in ("train_files","val_files") and not callable(getattr(args,a))), console=False) - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"params:{n_params} L:{args.num_layers} d:{args.model_dim} h:{args.num_heads} kv:{args.num_kv_heads} ws:{world_size} ga:{grad_accum_steps} s:{args.seed}") - - # --- Data loader & helpers --- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all(): - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float): - if args.warmdown_fraction <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) - return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 - warmdown_ms = max_wallclock_ms * args.warmdown_fraction - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - _seq_switched = False - _batch_switched = False - active_seq_len = args.seq_len_start if args.seq_len_start > 0 else args.train_seq_len - active_batch_tokens = args.batch_tokens_start if args.batch_tokens_start > 0 else args.train_batch_tokens - - # --- Compiler warmup --- - if args.warmup_steps > 0: - _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} - _os = [copy.deepcopy(o.state_dict()) for o in optimizers] - model.train() - for ws in range(args.warmup_steps): - zero_grad_all() - for mi in range(grad_accum_steps): - if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) - (loss * grad_scale).backward() - for o in optimizers: o.step() - zero_grad_all() - log0(f"warmup:{ws+1}/{args.warmup_steps}") - base_model.load_state_dict(_ms, strict=True) - for o, s in zip(optimizers, _os): o.load_state_dict(s) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # --- Main training loop --- - training_time_ms = 0.0 - stop_after_step: int | None = None - _untied = False - train_loss = torch.zeros((), device=device) - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - tstats = tern_stats(base_model, group_size=args.bitnet_group_size) - log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms zero_frac:{tstats['zero_frac']:.3f}") - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - - # Sequence length schedule - if args.seq_len_start > 0 and not _seq_switched: - if max_wallclock_ms is not None: - should_switch_seq = elapsed_ms >= args.seq_schedule_fraction * max_wallclock_ms - else: - should_switch_seq = step >= int(args.iterations * args.seq_schedule_fraction) - if should_switch_seq: - active_seq_len = args.train_seq_len - _seq_switched = True - torch._dynamo.reset() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - log0(f"step:{step} seq_len_switch:{args.seq_len_start}->{active_seq_len}") - - # Batch size schedule - if args.batch_tokens_start > 0 and not _batch_switched: - if max_wallclock_ms is not None: - should_switch_batch = elapsed_ms >= args.batch_schedule_fraction * max_wallclock_ms - else: - should_switch_batch = step >= int(args.iterations * args.batch_schedule_fraction) - if should_switch_batch: - active_batch_tokens = args.train_batch_tokens - _batch_switched = True - log0(f"step:{step} batch_switch:{args.batch_tokens_start}->{active_batch_tokens}") - - zero_grad_all() - train_loss.zero_() - - for micro in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = model(x, y) - train_loss.add_(loss.detach()) - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - # Untie lm_head at configured fraction of training - if args.untie_at_fraction > 0: - if max_wallclock_ms is not None: - should_untie = not _untied and elapsed_ms >= args.untie_at_fraction * max_wallclock_ms - else: - should_untie = not _untied and step >= int(args.iterations * args.untie_at_fraction) - if should_untie and base_model.tie_embeddings: - with torch.no_grad(): - base_weight = base_model.tok_emb.weight.float() - if base_model.lm_head_correction is not None: - base_weight = base_weight + base_model.lm_head_correction.float() - if base_model.embed_proj_rev is not None: - full_weight = base_weight @ base_model.embed_proj_rev.weight.float() - else: - full_weight = base_weight - base_model.lm_head.weight.copy_(full_weight) - base_model.tie_embeddings = False - base_model.lm_head.weight.requires_grad_(True) - for g in opt_head.param_groups: - g["lr"] = g["base_lr"] = args.head_lr - _untied = True - torch._dynamo.reset() - log0(f"step:{step} untied lm_head (head_lr={args.head_lr})") - - # Muon momentum warmup - if args.matrix_optimizer != "adam": - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - for g in opt_muon.param_groups: - g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - - # LR scheduling - for opt in optimizers: - for g in opt.param_groups: - g["lr"] = g["base_lr"] * scale - opt.step() - zero_grad_all() - step += 1 - approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if args.train_log_every > 0 and step % args.train_log_every == 0: - log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} t:{approx_ms:.0f}ms avg:{approx_ms/step:.1f}ms") - if args.churn_log_every > 0 and step % args.churn_log_every == 0: - log0(f"step:{step} churn:{churn_fn(base_model, args.bitnet_group_size):.4f} zero:{tern_stats(base_model, args.bitnet_group_size)['zero_frac']:.3f}") - - # Wallclock cap sync - if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: - reached_cap = approx_ms >= max_wallclock_ms - if distributed: - cap_t = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) - reached_cap = bool(cap_t.item()) - if reached_cap: - stop_after_step = step - - # --- Serialization --- - if master_process: - sd = base_model.state_dict() - if base_model.tie_embeddings or args.logit_head_type == "tversky": - sd.pop("lm_head.weight", None) - - # Compute ternary overrides for no-features Tversky prototypes - ternary_overrides = set() - for n, m in base_model.named_modules(): - if isinstance(m, TverskyProjection) and m.no_features_mode: - ternary_overrides.add(n + ".prototypes") - ternary_overrides = ternary_overrides or None - - # Two methods: Standard Base-3 vs Bitmask Mapping - methods = {} - for method in ("standard", "bitmask"): - q_obj, stats = q_sd(sd, group_size=args.bitnet_group_size, fp_storage=args.fp_storage, ternary_method=method, ternary_override_names=ternary_overrides) - buf = io.BytesIO() - torch.save(q_obj, buf) - methods[method] = {"blob": lzma.compress(buf.getvalue(), preset=9), "stats": stats} - best = min(methods, key=lambda m: len(methods[m]["blob"])) - final_blob, q_stats = methods[best]["blob"], methods[best]["stats"] - with open("final_model.ternary.ptz", "wb") as f: - f.write(final_blob) - - artifact_bytes = len(final_blob) - code_bytes = len(code.encode("utf-8")) - - total = artifact_bytes + code_bytes - log0(f"artifact:{artifact_bytes/1e6:.2f}MB ternary:{q_stats['ternary_params']}({q_stats['ternary_bytes']}B) fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") - log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) {'FITS' if total <= 16000000 else 'OVER'}") - - if args.eval_depth_recurrence > 0: - base_model.training_depth_recurrence = args.eval_depth_recurrence - log0(f"eval_depth_recurrence:{args.eval_depth_recurrence}") - - # --- All ranks load roundtrip weights and evaluate --- - if distributed: - dist.barrier() - - with open("final_model.ternary.ptz", "rb") as f: - loaded = torch.load(io.BytesIO(lzma.decompress(f.read())), map_location="cpu", weights_only=False) - base_model.load_state_dict(deq_sd(loaded), strict=False) - torch._dynamo.reset() - - q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - log0(f"final_ternary_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") - - opt_temp = 1.0 - if args.temp_scaling: - torch.cuda.synchronize() - t_temp = time.perf_counter() - calibration_tokens = train_loader.stream.take(65536).to(device) - opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut) - torch.cuda.synchronize() - temp_time_ms = 1000.0 * (time.perf_counter() - t_temp) - log0(f"temp_scaling optimal_T:{opt_temp:.2f} eval_time:{temp_time_ms:.0f}ms") - - if args.sliding_eval: - torch.cuda.synchronize() - t_sliding = time.perf_counter() - sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, stride=args.sliding_eval_stride, - temperature=opt_temp) - torch.cuda.synchronize() - sliding_time_ms = 1000.0 * (time.perf_counter() - t_sliding) - log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " - f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) eval_time:{sliding_time_ms:.0f}ms") - - if distributed: - dist.destroy_process_group() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md b/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md deleted file mode 100644 index 1b0e007981..0000000000 --- a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md +++ /dev/null @@ -1,99 +0,0 @@ -# Record: AR Self-Gen GPTQ + XSA-all + BigramHash 3072×112 - -**val_bpb: 1.1147** (3-seed mean, std 0.0004) | **~15.91 MB** | 8×H100 SXM, 600s | No TTT - -**This submission uses only AR (autoregressive) self-generated calibration data.** After training, the model autoregressively generates its own calibration tokens (64 seqs × 2048 tokens, temp=0.8). No val data and no train data are accessed during quantization. - -**Improvement over current SOTA ([PR #549](https://github.com/openai/parameter-golf/pull/549), 1.1194 BPB):** −0.0078 nats (−0.0046 BPB) - -## Results - -| Seed | Steps | ms/step | Pre-quant BPB | **Sliding BPB** | Artifact | -|------|-------|---------|---------------|-----------------|----------| -| 314 | 6,927 | 86.6 | 1.1354 | **1.1151** | 15,863,278 | -| 42 | 6,922 | 86.7 | 1.1349 | **1.1144** | 15,984,850 | -| 999 | 6,917 | 86.8 | 1.1353 | **1.1148** | 15,876,310 | -| **Mean** | | | | **1.1147** | | - -Current SOTA (PR #549, exact 3-seed mean): **1.11937967 BPB** (**1.89002068 nats**). This run's exact 3-seed mean is **1.11473509 BPB** (**1.88217853 nats**). Delta: **−0.00784215 nats** (**−0.00464458 BPB**). - -Using the exact per-seed scores from the PR #549 logs (`1.11922988`, `1.12002032`, `1.11888882`) and this run (`1.11508120`, `1.11437394`, `1.11475014`), Welch's t-test gives **t = -11.83**, **df ≈ 3.31**. - ---- - -## Main Changes - -The comparison baseline is [PR #549](https://github.com/openai/parameter-golf/pull/549), the current legal leaderboard entry at **1.1194 BPB**. The implementation lineage is closer to [PR #609](https://github.com/openai/parameter-golf/pull/609): this run keeps the XSA-all + Full GPTQ + selective-pruning stack, but uses AR self-generated GPTQ calibration (no external data), bumps BigramHash to **3072 × 112**, and uses `lzma preset=9`. - -### 1. AR Self-Generated Full Hessian GPTQ - -PR #549 used GPTQ-lite (diagonal Hessian approximation). We use Full Hessian GPTQ with Cholesky error compensation and column reordering — a strictly better quantizer. - -The calibration problem: prior Full Hessian GPTQ implementations (PRs #535, #569, #593, #609) calibrated on training data, ruled illegal after the 600s window. We solve this by having the model generate its own calibration data. After training completes, the model autoregressively generates 64 sequences of 2048 tokens (temperature=0.8, fixed seed). Hessians H = X^T X are collected from these self-generated sequences. No val data, no train data accessed during quantization. - -### 2. BigramHash 3072 × dim=112 (up from 1536) - -Lineage: [PR #549](https://github.com/openai/parameter-golf/pull/549) (1536) → [PR #609](https://github.com/openai/parameter-golf/pull/609) (2048) → this run (**3072 × dim=112**). Fits under 16MB; going wider increased artifact pressure past the break-even point. - -### 3. XSA on all 11 layers (up from last 4) - -PR #549 applied XSA to the last 4 layers. Extending to all 11 layers forces cross-position information mixing from layer 0 at zero parameter cost. Source: [PR #478](https://github.com/openai/parameter-golf/pull/478) by @gowtham0992. - -### Dropped: TTT - -PR #549 used Legal Score-First TTT for −0.0025 BPB. On this stack, TTT is neutral or negative (25 failed attempts across two stacks — see our [PR #756](https://github.com/openai/parameter-golf/pull/756)). The Full Hessian GPTQ improvement more than compensates for dropping TTT. - ---- - -## Architecture - -| Component | Setting | First introduced by | -|-----------|---------|---------------------| -| Layers | 11 (512d, 8 GQA heads, 4 KV heads) | Baseline | -| MLP | 3× (1536) with LeakyReLU(0.5)² | [#493](https://github.com/openai/parameter-golf/pull/493) @parinzee | -| Attention | XSA on all 11 layers | [#478](https://github.com/openai/parameter-golf/pull/478) @gowtham0992 | -| BigramHash | **3072 × dim=112** | **This work** (concept: [#162](https://github.com/openai/parameter-golf/pull/162) @raahilshah) | -| RoPE | Partial (16/64 dims) | [#315](https://github.com/openai/parameter-golf/pull/315) @jfprincz | -| LN Scale | 1/√(layer+1) | [#315](https://github.com/openai/parameter-golf/pull/315) @jfprincz | -| VE128 | Layers 9-10 | [#374](https://github.com/openai/parameter-golf/pull/374) @unnir | -| SmearGate | Position-mixing gate | [#65](https://github.com/openai/parameter-golf/pull/65) @aquariouseworkman | -| U-Net skips | Encoder-decoder connections | [#289](https://github.com/openai/parameter-golf/pull/289) | -| Weight avg | EMA(0.997) + Tight SWA(every 50) | [#401](https://github.com/openai/parameter-golf/pull/401) @newjordan | -| Quantization | **Full Hessian GPTQ int6 (AR self-gen calibration)** | **This work** (GPTQ: [#535](https://github.com/openai/parameter-golf/pull/535) @raahilshah) | -| Compression | LZMA preset=9 | [#160](https://github.com/openai/parameter-golf/pull/160) @ChaseWNorton | -| Warmdown | 4000 iterations | [#364](https://github.com/openai/parameter-golf/pull/364) @shikhar1729 | -| Optimizer | **Parallel Muon + Parameter Banking** | **[#399](https://github.com/openai/parameter-golf/pull/399) @abaybektursun** | -| Late QAT | STE at LR scale < 0.15 | [#286](https://github.com/openai/parameter-golf/pull/286) @chris-buckley | -| Selective pruning | ±1 values by reconstruction error | [#609](https://github.com/openai/parameter-golf/pull/609) @saml212 | -| Flash Attention 3 | Hopper warp-specialized kernels | [#122](https://github.com/openai/parameter-golf/pull/122) @mtybadger | - -## Requirements - -**Flash Attention 3 (Hopper) is required.** The script imports `flash_attn_interface` directly and was run with PyTorch 2.9.1+cu128. - -```bash -pip install --break-system-packages flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291 -pip install sentencepiece zstandard -python3 -c "from flash_attn_interface import flash_attn_func; import sentencepiece, zstandard; print('deps OK')" -``` - -## Run Command - -```bash -BIGRAM_VOCAB_SIZE=3072 BIGRAM_DIM=112 WARMDOWN_ITERS=4000 \ -TARGET_MB=15.9 SEED=314 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py -``` - -## Lineage - -``` -PR #549 (Legal SOTA, 1.1194) — our Parallel Muon base with LeakyReLU² + legal TTT - └── This work adds: - ├── AR self-gen GPTQ calibration (no external data during quantization) - ├── BigramHash 3072 × 112 (wider setting that still fits under 16MB) - ├── XSA-all (from #478/@gowtham0992, applied via #609/@saml212) - ├── Selective ±1 pruning (from #609/@saml212) - ├── warmdown=4000, LZMA=9 (from #364/@shikhar1729, #160/@ChaseWNorton) - └── Guided by PR #670 negative results (30+ failed experiments) -``` diff --git a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/requirements.txt b/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/requirements.txt deleted file mode 100644 index 8b0f870b9b..0000000000 --- a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -# FlashAttention 3 must be installed separately; see README.md -sentencepiece -zstandard diff --git a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/submission.json b/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/submission.json deleted file mode 100644 index cff849aa5a..0000000000 --- a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/submission.json +++ /dev/null @@ -1,54 +0,0 @@ -{ - "author": "abaybektursun", - "github_id": "abaybektursun", - "name": "AR Self-Gen GPTQ + XSA-all + BigramHash 3072x112", - "blurb": "11L XSA-all + Full Hessian GPTQ with autoregressive self-generated calibration (no val/train data accessed during quantization) + selective-pruning stack. BigramHash(3072,112), warmdown=4000, lzma preset=9. 3-seed exact mean: 1.11473509 BPB / 1.88217853 nats, beating PR549's exact 3-seed mean 1.11937967 BPB / 1.89002068 nats by 0.00784215 nats (Welch t=-11.83, df=3.31).", - "date": "2026-03-25", - "track": "10min_16mb", - "val_loss": 1.88217853, - "val_bpb": 1.11473509, - "val_loss_std": 0.00059750, - "val_bpb_std": 0.00035387, - "seeds": [314, 42, 999], - "seed_results": { - "314": { - "val_loss": 1.88276292, - "val_bpb": 1.11508120, - "artifact_bytes": 15863278, - "steps": 6927, - "step_avg_ms": 86.6 - }, - "42": { - "val_loss": 1.88156874, - "val_bpb": 1.11437394, - "artifact_bytes": 15984850, - "steps": 6922, - "step_avg_ms": 86.7 - }, - "999": { - "val_loss": 1.88220393, - "val_bpb": 1.11475014, - "artifact_bytes": 15876310, - "steps": 6917, - "step_avg_ms": 86.8 - } - }, - "comparison_baseline_pr": 549, - "implementation_lineage_pr": 609, - "negative_results_pr": 670, - "delta_vs_pr549_nats": -0.00784215, - "delta_vs_pr549_bpb": -0.00464458, - "t_statistic": -11.8339, - "welch_df": 3.3063, - "artifact_bytes_mean": 15908146, - "artifact_bytes_max": 15984850, - "bytes_total": 15984850, - "train_steps_mean": 6922.00, - "step_avg_ms_mean": 86.69, - "hardware": "8xH100 80GB SXM", - "pytorch_version": "2.9.1+cu128", - "cuda_version": "12.8", - "flash_attn_version": "2.8.3 (FA3 Hopper kernels)", - "calibration": "AR self-generated (64 seqs x 2048 tokens, temp=0.8, no external data)", - "technique_summary": "AR self-gen GPTQ calibration + XSA-all + BigramHash 3072x112 + Parallel Muon + LZMA9" -} diff --git a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_gpt.py b/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_gpt.py deleted file mode 100644 index 72c213f638..0000000000 --- a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_gpt.py +++ /dev/null @@ -1,2135 +0,0 @@ -from __future__ import annotations -import copy -import glob -import io -import lzma -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path -try: - import zstandard - _COMPRESSOR = "zstd" -except ImportError: - _COMPRESSOR = "zlib" -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func as flash_attn_3_func -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 11)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) - mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) - mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) - muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) - swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_every = int(os.environ.get("SWA_EVERY", 50)) - lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) - lawa_k = int(os.environ.get("LAWA_K", 10)) - lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) - muon_wd = float(os.environ.get("MUON_WD", 0.04)) - adam_wd = float(os.environ.get("ADAM_WD", 0.04)) - qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) - bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - trigram_enabled = bool(int(os.environ.get("TRIGRAM", "0"))) # TrigramHash (off by default, risky) - xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) # XSA on ALL layers (our novel contribution) - rope_dims = int(os.environ.get("ROPE_DIMS", 16)) - ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) - dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) - late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) - ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) - ve_dim = int(os.environ.get("VE_DIM", 128)) - ve_layers = os.environ.get("VE_LAYERS", "9,10") - gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) - value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) # VRL with sigmoid gates (off by default, risky) - # GPTQ calibration - gptq_calib_batches = int(os.environ.get("GPTQ_CALIB_BATCHES", 256)) - gptq_block_size = int(os.environ.get("GPTQ_BLOCK_SIZE", 128)) - -# --- Batched Newton-Schulz orthogonalization --- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: - """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" - a, b, c = (3.4445, -4.7750, 2.0315) - was_2d = G.ndim == 2 - if was_2d: - G = G.unsqueeze(0) - X = G.bfloat16() - transposed = X.size(-2) > X.size(-1) - if transposed: - X = X.mT - X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) - for _ in range(steps): - A = X @ X.mT - B = b * A + c * (A @ A) - X = a * X + B @ X - if transposed: - X = X.mT - if was_2d: - X = X.squeeze(0) - return X - -# --- Parallel Muon optimizer --- - -class Muon(torch.optim.Optimizer): - """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. - - No DDP for bank params. After backward, this optimizer: - 1. Launches async reduce-scatter for all banks (biggest first) - 2. Returns control so Adam can step on small params while RS is in-flight - 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather - 4. Each all-gather overlaps with next bank's NS5 - """ - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - self._built = False - - def _build(self): - self._distributed = dist.is_available() and dist.is_initialized() - self._world_size = dist.get_world_size() if self._distributed else 1 - self._rank = dist.get_rank() if self._distributed else 0 - ws = self._world_size - - self._bank_meta = [] - for group in self.param_groups: - for p in group["params"]: - B = p.shape[0] - padded_B = ((B + ws - 1) // ws) * ws - shard_B = padded_B // ws - tail = p.shape[1:] - dev = p.device - self._bank_meta.append({ - 'p': p, - 'B': B, - 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), - 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), - 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, - }) - # Sort by size descending -- launch biggest reduce-scatters first - self._bank_meta.sort(key=lambda m: -m['p'].numel()) - self._built = True - - def launch_reduce_scatters(self): - """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" - if not self._built: - self._build() - if not self._distributed: - return - self._rs_futures = [] - for m in self._bank_meta: - p = m['p'] - if p.grad is None: - self._rs_futures.append(None) - continue - pg = m['padded_grad'] - pg[:m['B']].copy_(p.grad.bfloat16()) - if pg.shape[0] > m['B']: - pg[m['B']:].zero_() - fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) - self._rs_futures.append(fut) - - @torch.no_grad() - def step(self, closure=None): - """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - if not self._built: - self._build() - - for group in self.param_groups: - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group.get("weight_decay", 0.0) - - prev_ag_handle = None - prev_m = None - - sharded = self._distributed and hasattr(self, '_rs_futures') - - for i, m in enumerate(self._bank_meta): - p = m['p'] - if p.grad is None: - continue - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if sharded and self._rs_futures[i] is not None: - self._rs_futures[i].wait() - g = m['shard'] - buf = m['shard_mom'] - else: - g = p.grad.bfloat16() - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - - buf.mul_(momentum).add_(g) - if nesterov: - update = g.add(buf, alpha=momentum) - else: - update = buf - - update = zeropower_via_newtonschulz5(update, steps=backend_steps) - - if sharded: - prev_ag_handle = dist.all_gather_into_tensor( - m['full_update'], update, async_op=True) - prev_m = m - else: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) - - if prev_ag_handle is not None: - prev_ag_handle.wait() - pp = prev_m['p'] - upd = prev_m['full_update'][:prev_m['B']] - if wd > 0.0: - pp.data.mul_(1.0 - lr * wd) - pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) - - if hasattr(self, '_rs_futures'): - del self._rs_futures - - return loss - -# --- Tokenizer evaluation helpers --- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - seq_len = eval_seq_len or args.train_seq_len - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# --- Quantization helpers --- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - -# --- Data loading --- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# --- Transformer modules --- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) -class CastedLinear(nn.Linear): - _qat_enabled: bool = False - def forward(self, x: Tensor) -> Tensor: - w = self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim == 2: - with torch.no_grad(): - w32 = self.weight.float() - row_max = w32.abs().amax(dim=1) - scale = (row_max / 31.0).clamp_min(1.0 / 31.0) - w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) - w = w + (w_q - w).detach() - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w, bias) -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): - super().__init__() - self.dim = dim - self.base = base - self.train_seq_len = train_seq_len - self.rope_dims = rope_dims if rope_dims > 0 else dim - inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - rd = self.rope_dims - if seq_len > self.train_seq_len: - scale = seq_len / self.train_seq_len - new_base = self.base * (scale ** (rd / (rd - 2))) - inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) - else: - inv_freq = self.inv_freq.to(device) - t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: - if rope_dims > 0 and rope_dims < x.size(-1): - x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] - half = rope_dims // 2 - x1, x2 = x_rope[..., :half], x_rope[..., half:] - x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - return torch.cat((x_rope, x_pass), dim=-1) - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - # No CastedLinear -- weights come from banks - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = 0 # set by GPT.__init__ for partial RoPE - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False # set by GPT.__init__ for deep layers only - # Gated attention and value residual (non-banked small params) - self.gated_attention = gated_attention - if gated_attention: - self.attn_gate = nn.Linear(dim, num_heads, bias=True) - nn.init.zeros_(self.attn_gate.weight) - nn.init.constant_(self.attn_gate.bias, 4.0) - self.value_residual = value_residual - if value_residual: - self.vrl_alpha = nn.Parameter(torch.zeros(1, dtype=torch.float32)) # sigmoid gate (PR #569 style) - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). - y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] - vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: - bsz, seqlen, dim = x.shape - q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = F.linear(x, v_w.to(x.dtype)) - if v_embed is not None: - v = v + v_embed - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - raw_v = v if self.value_residual else None - if self.value_residual and v0 is not None: - alpha = torch.sigmoid(self.vrl_alpha.to(dtype=v.dtype)) - v = v + alpha * v0 # sigmoid-gated residual (PR #569 style) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin, self.rope_dims) - k = apply_rotary_emb(k, cos, sin, self.rope_dims) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - if self.gated_attention: - # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout - gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) - y = y * gate - y = y.reshape(bsz, seqlen, dim) - return F.linear(y, out_w.to(x.dtype)), raw_v - -class SmearGate(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] - x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) - return (1 - g) * x + g * x_prev - -class BigramHashEmbedding(nn.Module): - def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int, trigram: bool = False): - super().__init__() - self.bigram_vocab_size = bigram_vocab_size - self._trigram = trigram - self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) - nn.init.zeros_(self.embed.weight) - self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) - def bigram_hash(self, tokens: Tensor) -> Tensor: - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., 0] = mod - out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod - return out.long() - def trigram_hash(self, tokens: Tensor) -> Tensor: - """Hash (t-2, t-1, t) trigrams into same embedding table. Zero extra params.""" - t = tokens.to(torch.int32) - mod = self.bigram_vocab_size - 1 - out = torch.empty_like(t) - out[..., :2] = mod - out[..., 2:] = (36313 * t[..., 2:] ^ 27191 * t[..., 1:-1] ^ 51497 * t[..., :-2]) % mod - return out.long() - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(self.bigram_hash(token_ids)) - if self._trigram: - h = h + self.embed(self.trigram_hash(token_ids)) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class ValueEmbedding(nn.Module): - """Reinject token identity into attention values at specific layers. - Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" - def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): - super().__init__() - self.embed = nn.Embedding(vocab_size, ve_dim) - nn.init.normal_(self.embed.weight, std=0.01) - self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None - if self.proj is not None: - nn.init.zeros_(self.proj.weight) - self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) - def forward(self, token_ids: Tensor) -> Tensor: - h = self.embed(token_ids) - if self.proj is not None: - h = self.proj(h) - return h * self.scale.to(dtype=h.dtype) - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - # No CastedLinear -- weights come from banks - def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: - x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) - return F.linear(x.square(), down_w.to(x.dtype)) - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - layer_idx: int = 0, - ln_scale: bool = False, - dtg: bool = False, - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - gated_attention=gated_attention, value_residual=value_residual) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - if dtg: - self.dtg_gate = nn.Linear(dim, 1, bias=True) - nn.init.zeros_(self.dtg_gate.weight) - nn.init.constant_(self.dtg_gate.bias, 2.0) - else: - self.dtg_gate = None - def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: - mix = self.resid_mix.to(dtype=x.dtype) - x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) - x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out - x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) - if self.dtg_gate is not None: - gate = torch.sigmoid(self.dtg_gate(x_in.detach())) - x_out = x_in + gate * (x_out - x_in) - return x_out, raw_v - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - mtp_num_heads: int = 0, - mtp_loss_weight: float = 0.1, - bigram_vocab_size: int = 0, - bigram_dim: int = 128, - xsa_last_n: int = 0, - rope_dims: int = 0, - ln_scale: bool = False, - dtg: bool = False, - ve_enabled: bool = False, - ve_dim: int = 128, - ve_layers: str = "9,10", - gated_attention: bool = False, - value_residual: bool = False, - ): - super().__init__() - self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.value_residual = value_residual - self.mtp_num_heads = mtp_num_heads - self.mtp_loss_weight = mtp_loss_weight - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim, trigram=bool(int(os.environ.get("TRIGRAM", "0")))) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - # Parameter banks: contiguous 3D tensors for batched optimizer - head_dim = model_dim // num_heads - kv_dim = num_kv_heads * head_dim - mlp_dim = int(mlp_mult * model_dim) - self.num_layers = num_layers - self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) - self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) - self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) - self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - layer_idx=i, - ln_scale=ln_scale, - dtg=dtg, - gated_attention=gated_attention, - value_residual=value_residual, - ) - for i in range(num_layers) - ] - ) - if rope_dims > 0: - head_dim = model_dim // num_heads - for block in self.blocks: - block.attn.rope_dims = rope_dims - block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] - kv_dim_ve = self._ve_target_dim - if self.ve_layer_indices: - self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) - self.ve_layer_scales = nn.ParameterList( - [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] - ) - else: - self.ve_shared = None - self.ve_layer_scales = nn.ParameterList() - self.value_embeds = nn.ModuleList() # keep empty for compat - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self.mtp_heads = nn.ModuleList( - [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] - ) - for head in self.mtp_heads: - head._zero_init = True - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - self._init_weights() - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - n = self.num_layers - proj_scale = 1.0 / math.sqrt(2 * n) - # Init banks: orthogonal, with proj layers scaled down and out/down zero-init - for i in range(n): - nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q - nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) - nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K - nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V - nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up - nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) - # Scale proj layers (out_proj and mlp_down are "proj" layers) - self.qo_bank.data[n + i].mul_(proj_scale) - self.mlp_down_bank.data[i].mul_(proj_scale) - # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) - for name, module in self.named_modules(): - if isinstance(module, nn.Linear): - if getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: - nn.init.orthogonal_(module.weight, gain=1.0) - def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: - """Get value embedding for a specific layer using shared table + per-layer scale.""" - if self.ve_shared is None or layer_idx not in self.ve_layer_indices: - return None - if ve_cache is not None and 've' not in ve_cache: - ve_cache['ve'] = self.ve_shared(input_ids) - ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) - ve_idx = self.ve_layer_indices.index(layer_idx) - return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - v0 = None - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x, raw_v = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve, v0=v0) - if v0 is None and raw_v is not None: - v0 = raw_v - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x, _ = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve, v0=v0) - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x_flat, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") - if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: - _, seqlen, dim = x.shape - mtp_loss_sum = x.new_zeros(()) - mtp_loss_count = 0 - for k, mtp_head in enumerate(self.mtp_heads): - valid_t = seqlen - (k + 1) - if valid_t <= 0: - continue - mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) - mtp_targets = target_ids[:, k + 1 :].reshape(-1) - mtp_logits_proj = mtp_head(mtp_hidden) - mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) - mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") - mtp_loss_count += 1 - if mtp_loss_count > 0: - main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) - return main_loss - def forward_logits(self, input_ids: Tensor) -> Tensor: - """Return logits (bsz, seq_len, vocab) without computing loss.""" - n = self.num_layers - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - v0 = None - skips: list[Tensor] = [] - ve_cache: dict = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x, raw_v = self.blocks[i](x, x0, - self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], - self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], - v_embed=ve, v0=v0) - if v0 is None and raw_v is not None: - v0 = raw_v - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x, _ = self.blocks[bi](x, x0, - self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], - self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], - v_embed=ve, v0=v0) - x = self.final_norm(x) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - logits_proj = self.lm_head(x) - return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - -# --- Sliding window evaluation --- - -def eval_val_sliding( - args: Hyperparameters, - base_model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, - stride: int, - batch_seqs: int = 32, - eval_seq_len: int | None = None, -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - seq_len = eval_seq_len or args.train_seq_len - total_tokens = val_tokens.numel() - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - base_model.eval() - compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = compiled_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = base_bytes_lut[tgt].to(torch.float64) - tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - val_loss = (loss_sum / token_count).item() - bits_per_token = val_loss / math.log(2.0) - tokens_per_byte = token_count.item() / byte_count.item() - base_model.train() - return val_loss, bits_per_token * tokens_per_byte - - -def generate_autoregressive_calib(model, device, num_seqs=64, seq_len=2048, - vocab_size=1024, temperature=0.8, batch_size=8, seed=42): - """Generate sequences autoregressively from the model for GPTQ calibration. - No external data accessed — fully self-contained.""" - model.eval() - rng = torch.Generator(device=device) - rng.manual_seed(seed) - all_tokens = [] - with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for batch_start in range(0, num_seqs, batch_size): - bs = min(batch_size, num_seqs - batch_start) - tokens = torch.randint(0, vocab_size, (bs, 1), device=device, generator=rng) - for pos in range(seq_len - 1): - logits = model.forward_logits(tokens) - next_logit = logits[:, -1, :] - probs = torch.softmax(next_logit / temperature, dim=-1) - next_tok = torch.multinomial(probs, 1, generator=rng) - tokens = torch.cat([tokens, next_tok], dim=1) - for i in range(bs): - all_tokens.append(tokens[i:i+1]) - return all_tokens - - -def collect_hessians_from_tokens(hessian_model, token_seqs, device): - """Collect H = X^T X from pre-generated token sequences.""" - hessians = {} - hooks = [] - for name, module in hessian_model.named_modules(): - if isinstance(module, CastedLinear): - param_name = name + ".weight" - cols = module.weight.shape[1] - hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') - def make_hook(pname): - def hook_fn(module, input, output): - x = input[0].detach().float() - if x.ndim == 3: - x = x.reshape(-1, x.shape[-1]) - hessians[pname] += (x.T @ x).cpu() - return hook_fn - h = module.register_forward_hook(make_hook(param_name)) - hooks.append(h) - hessian_model.eval() - with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for seq in token_seqs: - x = seq[:, :-1].to(device) - y = seq[:, 1:].to(device) - hessian_model(x, y) - for h in hooks: - h.remove() - num_batches = len(token_seqs) - for name in hessians: - H = hessians[name] - H /= num_batches - damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6) - H += damp * torch.eye(H.shape[0]) - hessians[name] = H - return hessians - - -# --- GPTQ-lite int6 quantization --- - -def _classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" -def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale - -def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128): - """Full GPTQ: Hessian-aware int6 quantization with Cholesky error compensation. - If hessian is None, falls back to percentile search.""" - t32 = weight.float() - if t32.ndim != 2 or hessian is None: - return _quantize_int6_percentile(t32, clip_range) - rows, cols = t32.shape - H = hessian.float().clone() - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - damp = 0.01 * torch.mean(torch.diag(H)) - H[torch.arange(cols), torch.arange(cols)] += damp - perm = torch.argsort(torch.diag(H), descending=True) - inv_perm = torch.argsort(perm) - W = t32[:, perm].clone() - W[:, dead[perm]] = 0 - H = H[perm][:, perm] - Hinv = torch.linalg.cholesky(H) - Hinv = torch.cholesky_inverse(Hinv) - Hinv = torch.linalg.cholesky(Hinv, upper=True) - best_q = None; best_scale = None; best_err = float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - sf = s.float() - Q = torch.zeros_like(W, dtype=torch.int8) - W_work = W.clone() - for i1 in range(0, cols, block_size): - i2 = min(i1 + block_size, cols) - count = i2 - i1 - W1 = W_work[:, i1:i2].clone() - Q1 = torch.zeros(rows, count, dtype=torch.int8) - Err1 = torch.zeros(rows, count) - Hinv1 = Hinv[i1:i2, i1:i2] - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8) - Q1[:, i] = q - err = (w - q.float() * sf) / d - W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0) - Err1[:, i] = err - Q[:, i1:i2] = Q1 - if i2 < cols: - W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] - recon = Q.float() * sf[:, None] - mse = (W - recon).pow(2).mean().item() - if mse < best_err: - best_q, best_scale, best_err = Q, s, mse - best_q = best_q[:, inv_perm] - return best_q, best_scale - -def _quantize_int6_percentile(t32, clip_range=31): - """Fallback: percentile search (for 1D or no-Hessian cases).""" - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale - -def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: - """Convert 3D bank tensors into individual 2D tensors with standard names.""" - out: dict[str, Tensor] = {} - n = num_layers - for name, tensor in sd.items(): - if name == "qo_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] - out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] - elif name == "kv_bank": - for i in range(n): - out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] - out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] - elif name == "mlp_up_bank": - for i in range(n): - out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] - elif name == "mlp_down_bank": - for i in range(n): - out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] - else: - out[name] = tensor - return out - -def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - """Convert individual 2D tensors back into 3D bank tensors.""" - out: dict[str, Tensor] = {} - n = num_layers - # Reconstruct banks from individual weight keys - qo_slices = [None] * (2 * n) - kv_slices = [None] * (2 * n) - up_slices = [None] * n - down_slices = [None] * n - consumed = set() - for i in range(n): - qk = f"blocks.{i}.attn.c_q.weight" - if qk in sd: - qo_slices[i] = sd[qk] - consumed.add(qk) - ok = f"blocks.{i}.attn.proj.weight" - if ok in sd: - qo_slices[n + i] = sd[ok] - consumed.add(ok) - kk = f"blocks.{i}.attn.c_k.weight" - if kk in sd: - kv_slices[i] = sd[kk] - consumed.add(kk) - vk = f"blocks.{i}.attn.c_v.weight" - if vk in sd: - kv_slices[n + i] = sd[vk] - consumed.add(vk) - fk = f"blocks.{i}.mlp.fc.weight" - if fk in sd: - up_slices[i] = sd[fk] - consumed.add(fk) - dk = f"blocks.{i}.mlp.proj.weight" - if dk in sd: - down_slices[i] = sd[dk] - consumed.add(dk) - out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) - out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) - out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) - out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) - for name, tensor in sd.items(): - if name not in consumed: - out[name] = tensor - return out - -# --- Non-banked model for Hessian collection --- -# This mirrors the unbanked state dict keys: blocks.{i}.attn.c_q/c_k/c_v/proj, blocks.{i}.mlp.fc/proj - -class _HessianAttn(nn.Module): - """Non-banked attention with CastedLinear layers for Hessian hooks.""" - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init): - super().__init__() - self.num_heads, self.num_kv_heads = num_heads, num_kv_heads - self.head_dim = dim // num_heads - kv_dim = num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rope_dims = 0 - self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) - self.use_xsa = False - def _xsa_efficient(self, y, v): - B, T, H, D = y.shape; Hkv = v.size(-2); group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - def forward(self, x, v_embed=None): - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = self.c_v(x) - if v_embed is not None: - v = v + v_embed - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin, self.rope_dims) - k = apply_rotary_emb(k, cos, sin, self.rope_dims) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - y = flash_attn_3_func(q, k, v, causal=True) - if self.use_xsa: - y = self._xsa_efficient(y, v) - return self.proj(y.reshape(bsz, seqlen, dim)) - -class _HessianMLP(nn.Module): - """Non-banked MLP with CastedLinear layers for Hessian hooks.""" - def __init__(self, dim, mlp_mult): - super().__init__() - self.fc = CastedLinear(dim, int(mlp_mult * dim), bias=False) - self.proj = CastedLinear(int(mlp_mult * dim), dim, bias=False) - def forward(self, x): - return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) - -class _HessianBlock(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, layer_idx=0, ln_scale=False): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = _HessianAttn(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = _HessianMLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 - def forward(self, x, x0, v_embed=None): - mix = self.resid_mix.to(dtype=x.dtype) - x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) - x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out - x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) - return x_out - -class _HessianGPT(nn.Module): - """Non-banked GPT model matching unbanked state dict keys for Hessian collection.""" - def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, - mlp_mult, tie_embeddings, logit_softcap, rope_base, qk_gain_init, - bigram_vocab_size=0, bigram_dim=128, xsa_last_n=0, - rope_dims=0, ln_scale=False, - ve_enabled=False, ve_dim=128, ve_layers="9,10"): - super().__init__() - self.tie_embeddings = tie_embeddings - self.logit_softcap = logit_softcap - self.num_layers = num_layers - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim, trigram=bool(int(os.environ.get("TRIGRAM", "0")))) if bigram_vocab_size > 0 else None - self.smear = SmearGate(model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList([ - _HessianBlock(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - layer_idx=i, ln_scale=ln_scale) - for i in range(num_layers) - ]) - if rope_dims > 0: - head_dim = model_dim // num_heads - for block in self.blocks: - block.attn.rope_dims = rope_dims - block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) - if xsa_last_n > 0: - for i in range(max(0, num_layers - xsa_last_n), num_layers): - self.blocks[i].attn.use_xsa = True - kv_dim = num_kv_heads * (model_dim // num_heads) - self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] - if self.ve_layer_indices: - self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) - self.ve_layer_scales = nn.ParameterList([nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]) - else: - self.ve_shared = None - self.ve_layer_scales = nn.ParameterList() - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - def _get_ve(self, layer_idx, input_ids, ve_cache): - if self.ve_shared is None or layer_idx not in self.ve_layer_indices: - return None - if 've' not in ve_cache: - ve_cache['ve'] = self.ve_shared(input_ids) - ve_idx = self.ve_layer_indices.index(layer_idx) - return ve_cache['ve'] * self.ve_layer_scales[ve_idx].to(dtype=ve_cache['ve'].dtype) - def forward(self, input_ids, target_ids): - x = self.tok_emb(input_ids) - if self.bigram is not None: - x = x + self.bigram(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x = self.smear(x) - x0 = x - skips = [] - ve_cache = {} - for i in range(self.num_encoder_layers): - ve = self._get_ve(i, input_ids, ve_cache) - x = self.blocks[i](x, x0, v_embed=ve) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - ve = self._get_ve(bi, input_ids, ve_cache) - x = self.blocks[bi](x, x0, v_embed=ve) - x = self.final_norm(x) - x_flat = x.reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - logits_proj = F.linear(x_flat, self.tok_emb.weight) if self.tie_embeddings else self.lm_head(x_flat) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - -def collect_hessians(hessian_model, train_loader, args, device, grad_accum_steps, num_batches=256): - """Run calibration batches through a non-banked model, collecting H = X^T X for each CastedLinear.""" - hessians = {} - hooks = [] - for name, module in hessian_model.named_modules(): - if isinstance(module, CastedLinear): - param_name = name + ".weight" - cols = module.weight.shape[1] - hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') - def make_hook(pname): - def hook_fn(module, input, output): - x = input[0].detach().float() - if x.ndim == 3: - x = x.reshape(-1, x.shape[-1]) - hessians[pname] += (x.T @ x).cpu() - return hook_fn - h = module.register_forward_hook(make_hook(param_name)) - hooks.append(h) - hessian_model.eval() - with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - for _ in range(num_batches): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - hessian_model(x, y) - for h in hooks: - h.remove() - for name in hessians: - H = hessians[name] - H /= num_batches - damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6) - H += damp * torch.eye(H.shape[0]) - hessians[name] = H - hessian_model.train() - return hessians - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], hessians: dict[str, Tensor] | None = None): - num_layers_total = max( - (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), - default=0, - ) + 1 - late_k_layers = set(range(num_layers_total - 2, num_layers_total)) - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = _classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - cr = 31 # int6 for all weights - H = hessians.get(name) if hessians else None - if H is not None: - q, s = quantize_int6_gptq(t, hessian=H, clip_range=cr) - else: - q, s = quantize_int6_per_row(t, clip_range=cr) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - -# --- Training --- - -def main() -> None: - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len - val_seq_len = max(args.train_seq_len, effective_eval_seq_len) - val_tokens = load_validation_tokens(args.val_files, val_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - CastedLinear._qat_enabled = args.qat_enabled - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - mtp_num_heads=args.mtp_num_heads, - mtp_loss_weight=args.mtp_loss_weight, - bigram_vocab_size=args.bigram_vocab_size, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, - ln_scale=args.ln_scale, - dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, - ve_dim=args.ve_dim, - ve_layers=args.ve_layers, - gated_attention=args.gated_attention, - value_residual=args.value_residual, - ).to(device).bfloat16() - # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward - base_model.qo_bank.data = base_model.qo_bank.data.float() - base_model.kv_bank.data = base_model.kv_bank.data.float() - base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() - base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, - # and non-bank grads are manually all-reduced before Adam steps. - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model = compiled_model - - # Optimizer split: - # - 4 parameter banks -> Muon (batched Newton-Schulz) - # - token embedding -> Adam - # - scalars/control tensors -> Adam - # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) - matrix_params = [ - base_model.qo_bank, base_model.kv_bank, - base_model.mlp_up_bank, base_model.mlp_down_bank, - ] - block_named_params = list(base_model.blocks.named_parameters()) - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not None: - scalar_params.append(base_model.bigram.scale) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.bigram is not None: - tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.bigram.proj is not None: - scalar_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - scalar_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) - optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - weight_decay=args.muon_wd, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - weight_decay=args.adam_wd, - fused=True, - ) - # Non-bank params that need manual all-reduce (replicated across GPUs) - replicated_params = list(optimizer_tok.param_groups[0]["params"]) - for pg in optimizer_tok.param_groups[1:]: - replicated_params.extend(pg["params"]) - replicated_params.extend(scalar_params) - - optimizer_head = None - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - replicated_params.append(base_model.lm_head.weight) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if optimizer_head is not None: - optimizers.append(optimizer_head) - n_params = sum(p.numel() for p in base_model.parameters()) - mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) - log0(f"model_params:{n_params}") - log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") - xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] - log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - # All-reduce all grads for warmup (simple, not optimized) - if distributed: - for p in base_model.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - swa_state: dict[str, Tensor] | None = None - swa_count = 0 - from collections import deque - lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = 0.997 - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: - CastedLinear._qat_enabled = True - log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - # === 3-phase overlapped optimizer step === - # Phase 1: Launch async reduce-scatter for banks (biggest first) - optimizer_muon.launch_reduce_scatters() - # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) - if distributed: - for p in replicated_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - optimizer_tok.step() - optimizer_scalar.step() - if optimizer_head is not None: - optimizer_head.step() - # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) - optimizer_muon.step() - zero_grad_all() - # EMA update - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: - if swa_state is None: - swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} - swa_count = 1 - log0(f"swa:start step:{step}") - else: - for name, t in base_model.state_dict().items(): - swa_state[name] += t.detach().cpu() - swa_count += 1 - if args.lawa_enabled and step % args.lawa_freq == 0: - lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - # Apply weight averaging - if args.lawa_enabled and len(lawa_queue) > 1: - log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") - current_state = base_model.state_dict() - avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} - for snap in lawa_queue: - for name in avg_state: - avg_state[name] += snap[name].float() - for name in avg_state: - avg_state[name] /= len(lawa_queue) - avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) - base_model.load_state_dict(avg_state, strict=True) - else: - log0("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - torch.cuda.synchronize() - t_diag = time.perf_counter() - diag_val_loss, diag_val_bpb = eval_val( - args, compiled_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" - ) - full_state_dict = base_model.state_dict() - export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} - excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) - if excluded_mtp > 0: - log0(f"export_excluding_mtp_params:{excluded_mtp}") - if master_process: - torch.save(export_sd, "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - # Unbank 3D tensors into individual 2D tensors for quantization - sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} - unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) - # Full GPTQ: collect Hessians via a temporary non-banked model - log0(f"gptq:building non-banked model for Hessian collection...") - hessian_model = _HessianGPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, logit_softcap=args.logit_softcap, - rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - ).to(device).bfloat16() - for m in hessian_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(hessian_model) - # Load unbanked weights into the non-banked model - hessian_model.load_state_dict( - {k: v.to(device) for k, v in unbanked_sd.items() if k in hessian_model.state_dict()}, - strict=False, - ) - # Autoregressive self-generated calibration (no external data) - log0("gptq:generating autoregressive calibration data (64 seqs x 2048 tokens, temp=0.8)...") - base_model.load_state_dict(export_sd, strict=False) - t_gen = time.perf_counter() - ar_tokens = generate_autoregressive_calib( - base_model, device, num_seqs=64, seq_len=args.train_seq_len, - vocab_size=args.vocab_size, temperature=0.8, batch_size=8, seed=args.seed, - ) - log0(f"gptq:generated {len(ar_tokens)} sequences in {time.perf_counter()-t_gen:.1f}s") - log0("gptq:collecting hessians from autoregressive data...") - hessians = collect_hessians_from_tokens(hessian_model, ar_tokens, device) - log0(f"gptq:collected hessians for {len(hessians)} layers (AR self-gen)") - del ar_tokens - del hessian_model - torch.cuda.empty_cache() - quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}, hessians=hessians) - # NOVEL: Selective ±1 pruning by reconstruction error - # Sort ±1 quantized values by their reconstruction error (scale²), - # prune least-impactful first until artifact fits target size. - target_mb = float(os.environ.get("TARGET_MB", "15.9")) - code_bytes_est = len(code.encode("utf-8")) - ones_info = [] # (tensor_key, flat_idx, error) - for name, info in quant_meta.items(): - if not (isinstance(info, dict) and info.get("type") == "int6"): continue - qk, sk = name + ".q", name + ".scale" - if qk not in quant_result or sk not in quant_result: continue - q, s = quant_result[qk], quant_result[sk] - if s.ndim > 0: - ones_mask = (q.abs() == 1) - if ones_mask.any(): - row_idx = torch.arange(q.shape[0]).unsqueeze(1).expand_as(q)[ones_mask] - flat_idx = torch.arange(q.numel()).reshape(q.shape)[ones_mask] - errors = s.float()[row_idx].pow(2) - for fi, err in zip(flat_idx.tolist(), errors.tolist()): - ones_info.append((qk, fi, err)) - if ones_info: - ones_info.sort(key=lambda x: x[2]) - def _try_prune(n): - tmp = {k: v.clone() for k, v in quant_result.items()} - for i in range(min(n, len(ones_info))): - tmp[ones_info[i][0]].view(-1)[ones_info[i][1]] = 0 - buf = io.BytesIO(); torch.save({"w": tmp, "m": quant_meta}, buf) - return len(lzma.compress(buf.getvalue(), preset=9)) + code_bytes_est, tmp - no_sz, _ = _try_prune(0) - target_bytes = int(target_mb * 1024 * 1024) - log0(f"selective_prune: {len(ones_info)} ±1 candidates, unpruned={no_sz/(1024*1024):.2f}MB target={target_mb}MB") - if no_sz <= target_bytes: - log0("selective_prune: already fits, no pruning needed") - else: - full_sz, _ = _try_prune(len(ones_info)) - log0(f"selective_prune: full ±1 prune={full_sz/(1024*1024):.2f}MB") - if full_sz > target_bytes: - log0("selective_prune: even full prune not enough, applying all") - _, quant_result = _try_prune(len(ones_info)) - else: - lo, hi = 0, len(ones_info) - while lo < hi: - mid = (lo + hi) // 2 - sz, _ = _try_prune(mid) - if sz <= target_bytes: hi = mid - else: lo = mid + 1 - log0(f"selective_prune: pruning {lo}/{len(ones_info)} ±1 values ({100*lo/len(ones_info):.1f}%) to fit {target_mb}MB") - _, quant_result = _try_prune(lo) - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = lzma.compress(quant_raw, preset=9) - if master_process: - with open("final_model.int6.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = len(quant_blob) - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") - log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") - if distributed: - dist.barrier() - with open("final_model.int6.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(lzma.decompress(quant_blob_disk)), - map_location="cpu", - ) - deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) - # Re-bank the dequantized tensors - deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) - eval_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - mtp_num_heads=0, mtp_loss_weight=0.0, - bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, - ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, - gated_attention=args.gated_attention, value_residual=args.value_residual, - ).to(device).bfloat16() - eval_model.qo_bank.data = eval_model.qo_bank.data.float() - eval_model.kv_bank.data = eval_model.kv_bank.data.float() - eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() - eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m, CastedLinear): - m.float() - restore_low_dim_params_to_fp32(eval_model) - eval_model.load_state_dict(deq_state, strict=True) - compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, compiled_eval, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - eval_seq_len=effective_eval_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - sw_seq_len = effective_eval_seq_len - if args.eval_stride > 0 and args.eval_stride < sw_seq_len: - torch.cuda.synchronize() - t_slide = time.perf_counter() - sw_val_loss, sw_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=args.eval_stride, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " - f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" - ) - log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - if args.eval_stride != 64 and 64 < sw_seq_len: - torch.cuda.synchronize() - t_slide64 = time.perf_counter() - sw64_val_loss, sw64_val_bpb = eval_val_sliding( - args, eval_model, rank, world_size, device, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride=64, - eval_seq_len=sw_seq_len, - ) - torch.cuda.synchronize() - log0( - f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " - f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" - ) - log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") - if distributed: - dist.destroy_process_group() -if __name__ == "__main__": - main() diff --git a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed314.log b/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed314.log deleted file mode 100644 index 8375b35f35..0000000000 --- a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed314.log +++ /dev/null @@ -1,85 +0,0 @@ -W0326 20:30:26.730000 8512 torch/distributed/run.py:803] -W0326 20:30:26.730000 8512 torch/distributed/run.py:803] ***************************************** -W0326 20:30:26.730000 8512 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0326 20:30:26.730000 8512 torch/distributed/run.py:803] ***************************************** -logs/5434c191-7955-4256-b8bf-1dc361d0d86f.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:27067484 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:314 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9271 val_bpb:4.1026 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9298 train_time:135ms step_avg:134.76ms -step:2/20000 train_loss:8.6135 train_time:165ms step_avg:82.66ms -step:3/20000 train_loss:7.6124 train_time:249ms step_avg:82.96ms -step:4/20000 train_loss:7.3645 train_time:333ms step_avg:83.23ms -step:5/20000 train_loss:7.1467 train_time:417ms step_avg:83.41ms -step:6/20000 train_loss:7.0060 train_time:501ms step_avg:83.55ms -step:7/20000 train_loss:6.9248 train_time:586ms step_avg:83.76ms -step:8/20000 train_loss:6.7919 train_time:671ms step_avg:83.85ms -step:9/20000 train_loss:6.4482 train_time:755ms step_avg:83.91ms -step:10/20000 train_loss:6.0553 train_time:839ms step_avg:83.95ms -step:500/20000 train_loss:2.3787 train_time:42942ms step_avg:85.88ms -step:1000/20000 train_loss:2.2509 train_time:86053ms step_avg:86.05ms -step:1500/20000 train_loss:2.1982 train_time:129210ms step_avg:86.14ms -step:2000/20000 train_loss:2.0412 train_time:172475ms step_avg:86.24ms -step:2500/20000 train_loss:2.1464 train_time:215777ms step_avg:86.31ms -step:3000/20000 train_loss:2.1423 train_time:259072ms step_avg:86.36ms -step:3500/20000 train_loss:2.1495 train_time:302369ms step_avg:86.39ms -step:4000/20000 train_loss:1.9433 train_time:345683ms step_avg:86.42ms -step:4000/20000 val_loss:2.0348 val_bpb:1.2051 train_time:345740ms step_avg:86.43ms -step:4500/20000 train_loss:2.0982 train_time:388997ms step_avg:86.44ms -step:5000/20000 train_loss:2.0805 train_time:432313ms step_avg:86.46ms -step:5500/20000 train_loss:1.9939 train_time:475594ms step_avg:86.47ms -step:6000/20000 train_loss:1.9209 train_time:518844ms step_avg:86.47ms -swa:start step:6150 -late_qat:enabled step:6335 scale:0.1498 -step:6500/20000 train_loss:2.0612 train_time:562554ms step_avg:86.55ms -step:6927/20000 val_loss:1.9171 val_bpb:1.1354 train_time:600109ms step_avg:86.63ms -stopping_early: wallclock_cap train_time:600109ms step:6927/20000 -peak memory allocated: 22858 MiB reserved: 22924 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9155 val_bpb:1.1344 eval_time:2059ms -Serialized model: 106289590 bytes -Code size: 101850 bytes -gptq:building non-banked model for Hessian collection... -gptq:generating autoregressive calibration data (64 seqs x 2048 tokens, temp=0.8)... -gptq:generated 64 sequences in 196.7s -gptq:collecting hessians from autoregressive data... -gptq:collected hessians for 68 layers (AR self-gen) -selective_prune: 4207533 ±1 candidates, unpruned=15.13MB target=15.9MB -selective_prune: already fits, no pruning needed -Serialized model int6+lzma: 15761428 bytes -Total submission size int6+lzma: 15863278 bytes -final_int6_roundtrip val_loss:1.9225 val_bpb:1.1386 eval_time:23007ms -final_int6_roundtrip_exact val_loss:1.92248956 val_bpb:1.13860661 -final_int6_sliding_window val_loss:1.8828 val_bpb:1.1151 stride:64 eval_time:105090ms -final_int6_sliding_window_exact val_loss:1.88276292 val_bpb:1.11508120 -final_int8_zlib_roundtrip_exact val_loss:1.88276292 val_bpb:1.11508120 diff --git a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed42.log b/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed42.log deleted file mode 100644 index ca8b176ae7..0000000000 --- a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed42.log +++ /dev/null @@ -1,85 +0,0 @@ -W0326 20:50:06.519000 66486 torch/distributed/run.py:803] -W0326 20:50:06.519000 66486 torch/distributed/run.py:803] ***************************************** -W0326 20:50:06.519000 66486 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0326 20:50:06.519000 66486 torch/distributed/run.py:803] ***************************************** -logs/d1e51d8b-edcf-4543-9c30-8d0636896131.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:27067484 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:42 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9307 val_bpb:4.1048 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9316 train_time:136ms step_avg:135.73ms -step:2/20000 train_loss:8.7430 train_time:166ms step_avg:82.95ms -step:3/20000 train_loss:7.6321 train_time:250ms step_avg:83.24ms -step:4/20000 train_loss:7.2316 train_time:334ms step_avg:83.51ms -step:5/20000 train_loss:7.1692 train_time:420ms step_avg:84.02ms -step:6/20000 train_loss:7.0905 train_time:504ms step_avg:84.04ms -step:7/20000 train_loss:6.9854 train_time:589ms step_avg:84.09ms -step:8/20000 train_loss:6.7960 train_time:673ms step_avg:84.16ms -step:9/20000 train_loss:6.4285 train_time:759ms step_avg:84.31ms -step:10/20000 train_loss:6.0222 train_time:843ms step_avg:84.32ms -step:500/20000 train_loss:2.3854 train_time:42993ms step_avg:85.99ms -step:1000/20000 train_loss:2.2586 train_time:86137ms step_avg:86.14ms -step:1500/20000 train_loss:2.2018 train_time:129307ms step_avg:86.20ms -step:2000/20000 train_loss:2.0412 train_time:172573ms step_avg:86.29ms -step:2500/20000 train_loss:2.1523 train_time:215865ms step_avg:86.35ms -step:3000/20000 train_loss:2.1411 train_time:259118ms step_avg:86.37ms -step:3500/20000 train_loss:2.1530 train_time:302408ms step_avg:86.40ms -step:4000/20000 train_loss:1.9448 train_time:345735ms step_avg:86.43ms -step:4000/20000 val_loss:2.0348 val_bpb:1.2051 train_time:345792ms step_avg:86.45ms -step:4500/20000 train_loss:2.0954 train_time:389043ms step_avg:86.45ms -step:5000/20000 train_loss:2.0762 train_time:432310ms step_avg:86.46ms -step:5500/20000 train_loss:1.9973 train_time:475604ms step_avg:86.47ms -step:6000/20000 train_loss:1.9187 train_time:518888ms step_avg:86.48ms -swa:start step:6150 -late_qat:enabled step:6333 scale:0.1498 -step:6500/20000 train_loss:2.0628 train_time:562798ms step_avg:86.58ms -step:6922/20000 val_loss:1.9162 val_bpb:1.1349 train_time:600058ms step_avg:86.69ms -stopping_early: wallclock_cap train_time:600058ms step:6922/20000 -peak memory allocated: 22847 MiB reserved: 22894 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9146 val_bpb:1.1340 eval_time:2062ms -Serialized model: 106289590 bytes -Code size: 101850 bytes -gptq:building non-banked model for Hessian collection... -gptq:generating autoregressive calibration data (64 seqs x 2048 tokens, temp=0.8)... -gptq:generated 64 sequences in 198.3s -gptq:collecting hessians from autoregressive data... -gptq:collected hessians for 68 layers (AR self-gen) -selective_prune: 4212332 ±1 candidates, unpruned=15.24MB target=15.9MB -selective_prune: already fits, no pruning needed -Serialized model int6+lzma: 15883000 bytes -Total submission size int6+lzma: 15984850 bytes -final_int6_roundtrip val_loss:1.9216 val_bpb:1.1381 eval_time:7093ms -final_int6_roundtrip_exact val_loss:1.92161667 val_bpb:1.13808963 -final_int6_sliding_window val_loss:1.8816 val_bpb:1.1144 stride:64 eval_time:77178ms -final_int6_sliding_window_exact val_loss:1.88156874 val_bpb:1.11437394 -final_int8_zlib_roundtrip_exact val_loss:1.88156874 val_bpb:1.11437394 diff --git a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed999.log b/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed999.log deleted file mode 100644 index f1d62a214c..0000000000 --- a/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_seed999.log +++ /dev/null @@ -1,85 +0,0 @@ -W0326 21:07:17.732000 67802 torch/distributed/run.py:803] -W0326 21:07:17.732000 67802 torch/distributed/run.py:803] ***************************************** -W0326 21:07:17.732000 67802 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0326 21:07:17.732000 67802 torch/distributed/run.py:803] ***************************************** -logs/c39e968f-fc0a-4996-9304-7ef4c2b72dc4.txt -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:80 -val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:27067484 -mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 -XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 -train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 -seed:999 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/20000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9316 train_time:134ms step_avg:133.66ms -step:2/20000 train_loss:8.6443 train_time:165ms step_avg:82.49ms -step:3/20000 train_loss:7.5750 train_time:249ms step_avg:82.90ms -step:4/20000 train_loss:7.3107 train_time:333ms step_avg:83.31ms -step:5/20000 train_loss:7.1701 train_time:418ms step_avg:83.58ms -step:6/20000 train_loss:7.0637 train_time:502ms step_avg:83.69ms -step:7/20000 train_loss:7.0150 train_time:587ms step_avg:83.79ms -step:8/20000 train_loss:6.8799 train_time:672ms step_avg:83.96ms -step:9/20000 train_loss:6.4639 train_time:756ms step_avg:84.01ms -step:10/20000 train_loss:6.0463 train_time:841ms step_avg:84.06ms -step:500/20000 train_loss:2.3979 train_time:42999ms step_avg:86.00ms -step:1000/20000 train_loss:2.2588 train_time:86110ms step_avg:86.11ms -step:1500/20000 train_loss:2.2040 train_time:129306ms step_avg:86.20ms -step:2000/20000 train_loss:2.0465 train_time:172584ms step_avg:86.29ms -step:2500/20000 train_loss:2.1497 train_time:215933ms step_avg:86.37ms -step:3000/20000 train_loss:2.1412 train_time:259292ms step_avg:86.43ms -step:3500/20000 train_loss:2.1508 train_time:302674ms step_avg:86.48ms -step:4000/20000 train_loss:1.9437 train_time:346007ms step_avg:86.50ms -step:4000/20000 val_loss:2.0350 val_bpb:1.2053 train_time:346063ms step_avg:86.52ms -step:4500/20000 train_loss:2.0976 train_time:389355ms step_avg:86.52ms -step:5000/20000 train_loss:2.0791 train_time:432705ms step_avg:86.54ms -step:5500/20000 train_loss:1.9952 train_time:476052ms step_avg:86.55ms -step:6000/20000 train_loss:1.9200 train_time:519377ms step_avg:86.56ms -swa:start step:6150 -late_qat:enabled step:6327 scale:0.1498 -step:6500/20000 train_loss:2.0611 train_time:563277ms step_avg:86.66ms -step:6917/20000 val_loss:1.9169 val_bpb:1.1353 train_time:600137ms step_avg:86.76ms -stopping_early: wallclock_cap train_time:600137ms step:6917/20000 -peak memory allocated: 22847 MiB reserved: 22894 MiB -ema:applying EMA weights -DIAGNOSTIC post_ema val_loss:1.9153 val_bpb:1.1343 eval_time:2056ms -Serialized model: 106289590 bytes -Code size: 101850 bytes -gptq:building non-banked model for Hessian collection... -gptq:generating autoregressive calibration data (64 seqs x 2048 tokens, temp=0.8)... -gptq:generated 64 sequences in 196.5s -gptq:collecting hessians from autoregressive data... -gptq:collected hessians for 68 layers (AR self-gen) -selective_prune: 4198459 ±1 candidates, unpruned=15.14MB target=15.9MB -selective_prune: already fits, no pruning needed -Serialized model int6+lzma: 15774460 bytes -Total submission size int6+lzma: 15876310 bytes -final_int6_roundtrip val_loss:1.9220 val_bpb:1.1383 eval_time:6796ms -final_int6_roundtrip_exact val_loss:1.92204521 val_bpb:1.13834344 -final_int6_sliding_window val_loss:1.8822 val_bpb:1.1148 stride:64 eval_time:77150ms -final_int6_sliding_window_exact val_loss:1.88220393 val_bpb:1.11475014 -final_int8_zlib_roundtrip_exact val_loss:1.88220393 val_bpb:1.11475014 diff --git a/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/README.md b/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/README.md deleted file mode 100644 index c36c90b816..0000000000 --- a/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/README.md +++ /dev/null @@ -1,56 +0,0 @@ -This record captures an unlimited-compute non-record submission built from the current root `train_gpt.py`. - -This run is not intended to satisfy the 10-minute cutoff for the main leaderboard. It uses the same 9x512 SP-1024 tied-embedding baseline layout, but extends training to a 4-hour wallclock cap on `pgut3` while evaluating against the full 50k-document validation split every 20k steps. - -Configuration: -- Track: `non-record`, unlimited compute, still under the `16,000,000` byte artifact cap -- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` -- Tied output/input embeddings: `TIE_EMBEDDINGS=1` -- Tied embedding LR: `TIED_EMBED_LR=0.05` -- Batching: `TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=1024` -- Validation cadence: `VAL_LOSS_EVERY=20000` on the full `fineweb_val_*` split - -Command (track-relevant params): -```bash -OMP_NUM_THREADS=1 \ -TORCH_NCCL_ASYNC_ERROR_HANDLING=1 \ -RUN_ID=train_gpt_pgut3_quasi10b_sp1024_4h_20260318_075102 \ -DATA_PATH=/tmp/fineweb_quasi10Bfrom50B_50keval_sp1024_v0/datasets/fineweb10B_sp1024 \ -TOKENIZER_PATH=/tmp/fineweb_quasi10Bfrom50B_50keval_sp1024_v0/tokenizers/fineweb_1024_bpe.model \ -VOCAB_SIZE=1024 \ -NUM_LAYERS=9 \ -MODEL_DIM=512 \ -NUM_HEADS=8 \ -NUM_KV_HEADS=4 \ -MLP_MULT=2 \ -TIE_EMBEDDINGS=1 \ -TIED_EMBED_LR=0.05 \ -ITERATIONS=500000 \ -WARMUP_STEPS=20 \ -MAX_WALLCLOCK_SECONDS=14400 \ -TRAIN_BATCH_TOKENS=524288 \ -TRAIN_SEQ_LEN=1024 \ -TRAIN_LOG_EVERY=200 \ -VAL_LOSS_EVERY=20000 \ -torchrun --standalone --nproc_per_node=8 /root/code/parameter-golf/train_gpt.py -``` - -Key metrics (from `train.log`): -- Timed training stopped at `329430/500000` steps due to the 4-hour wallclock cap. -- Best pre-quant eval at stop: `val_loss:1.9837`, `val_bpb:1.1749` -- Post-quant roundtrip eval: `val_loss:2.0386`, `val_bpb:1.2074` -- Exact printed metric: `final_int8_zlib_roundtrip_exact val_bpb:1.20737944` -- Train time: `14400039ms` (`step_avg:43.71ms`) -- Peak memory: `10184 MiB allocated`, `10588 MiB reserved` -- Serialized model int8+zlib: `15762519 bytes` -- Code size: `47642 bytes` -- Total submission size int8+zlib: `15810161 bytes` - -Training volume: -- Global batch: `524288` tokens/step -- Total train tokens seen: `172716195840` - -Included files: -- `train_gpt.py` (code snapshot used for the run) -- `train.log` (exact remote training log) -- `submission.json` (leaderboard metadata) diff --git a/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/submission.json b/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/submission.json deleted file mode 100644 index 4192982daa..0000000000 --- a/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/submission.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "author": "Will DePue", - "github_id": "williamd", - "name": "4-Hour Quasi-10B SP1024", - "blurb": "Unlimited compute track: SP-1024 9x512 KV4 run on pgut3 for 4 hours against the quasi10Bfrom50B 50k-eval export; pre-quant reached 1.1749 BPB at wallclock stop and final int8+zlib roundtrip scored 1.2074 under the 16,000,000-byte cap.", - "date": "2026-03-18T11:53:00Z", - "track": "non-record-unlimited-compute-16mb", - "val_loss": 2.03860961, - "val_bpb": 1.20737944, - "pre_quant_val_loss": 1.9837, - "pre_quant_val_bpb": 1.1749, - "step_stop": 329430, - "wallclock_seconds": 14400.039, - "bytes_total": 15810161, - "bytes_model_int8_zlib": 15762519, - "bytes_code": 47642 -} diff --git a/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/train.log b/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/train.log deleted file mode 100644 index fbc482360b..0000000000 --- a/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/train.log +++ /dev/null @@ -1,2901 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.9 (main, Jan 22 2026, 18:37:37) [GCC 13.3.0] -Running PyTorch 2.10.0+cu128 -Wed Mar 18 07:51:20 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 13.0 | -|-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000001:00:00.0 Off | 0 | -| N/A 31C P0 119W / 700W | 1774MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000002:00:00.0 Off | 0 | -| N/A 31C P0 120W / 700W | 1822MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000003:00:00.0 Off | 0 | -| N/A 30C P0 117W / 700W | 1822MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000008:00:00.0 Off | 0 | -| N/A 30C P0 119W / 700W | 1822MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000009:00:00.0 Off | 0 | -| N/A 29C P0 120W / 700W | 1822MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 0000000A:00:00.0 Off | 0 | -| N/A 31C P0 119W / 700W | 1822MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 0000000B:00:00.0 Off | 0 | -| N/A 30C P0 116W / 700W | 1822MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 0000000C:00:00.0 Off | 0 | -| N/A 30C P0 118W / 700W | 1582MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| 0 N/A N/A 987032 C ...t/.pyenv/versions/3.12.9/bin/python 1764MiB | -| 1 N/A N/A 987033 C ...t/.pyenv/versions/3.12.9/bin/python 1812MiB | -| 2 N/A N/A 987034 C ...t/.pyenv/versions/3.12.9/bin/python 1812MiB | -| 3 N/A N/A 987035 C ...t/.pyenv/versions/3.12.9/bin/python 1812MiB | -| 4 N/A N/A 987036 C ...t/.pyenv/versions/3.12.9/bin/python 1812MiB | -| 5 N/A N/A 987037 C ...t/.pyenv/versions/3.12.9/bin/python 1812MiB | -| 6 N/A N/A 987038 C ...t/.pyenv/versions/3.12.9/bin/python 1812MiB | -| 7 N/A N/A 987039 C ...t/.pyenv/versions/3.12.9/bin/python 1572MiB | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== -val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/tmp/fineweb_quasi10Bfrom50B_50keval_sp1024_v0/tokenizers/fineweb_1024_bpe.model -train_loader:dataset:fineweb10B_sp1024 train_shards:195 -val_loader:shards pattern=/tmp/fineweb_quasi10Bfrom50B_50keval_sp1024_v0/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 -model_params:17059912 -world_size:8 grad_accum_steps:1 -sdp_backends:cudnn=False flash=True mem_efficient=False math=False -attention_mode:gqa num_heads:8 num_kv_heads:4 -tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 -train_batch_tokens:524288 train_seq_len:1024 iterations:500000 warmup_steps:20 max_wallclock_seconds:14400.000 -seed:1337 -warmup_step:1/20 -warmup_step:2/20 -warmup_step:3/20 -warmup_step:4/20 -warmup_step:5/20 -warmup_step:6/20 -warmup_step:7/20 -warmup_step:8/20 -warmup_step:9/20 -warmup_step:10/20 -warmup_step:11/20 -warmup_step:12/20 -warmup_step:13/20 -warmup_step:14/20 -warmup_step:15/20 -warmup_step:16/20 -warmup_step:17/20 -warmup_step:18/20 -warmup_step:19/20 -warmup_step:20/20 -step:0/500000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.02ms -step:1/500000 train_loss:6.9370 train_time:33ms step_avg:32.84ms -step:2/500000 train_loss:16.8366 train_time:77ms step_avg:38.37ms -step:3/500000 train_loss:8.7611 train_time:120ms step_avg:39.91ms -step:4/500000 train_loss:6.6384 train_time:162ms step_avg:40.39ms -step:5/500000 train_loss:6.6118 train_time:206ms step_avg:41.14ms -step:6/500000 train_loss:7.4217 train_time:249ms step_avg:41.44ms -step:7/500000 train_loss:6.3501 train_time:292ms step_avg:41.72ms -step:8/500000 train_loss:6.1581 train_time:336ms step_avg:41.97ms -step:9/500000 train_loss:6.0678 train_time:378ms step_avg:42.02ms -step:10/500000 train_loss:5.9746 train_time:422ms step_avg:42.15ms -step:200/500000 train_loss:2.8504 train_time:8738ms step_avg:43.69ms -step:400/500000 train_loss:2.3638 train_time:17484ms step_avg:43.71ms -step:600/500000 train_loss:2.5453 train_time:26216ms step_avg:43.69ms -step:800/500000 train_loss:2.2913 train_time:34960ms step_avg:43.70ms -step:1000/500000 train_loss:2.3707 train_time:43708ms step_avg:43.71ms -step:1200/500000 train_loss:2.3889 train_time:52441ms step_avg:43.70ms -step:1400/500000 train_loss:2.4331 train_time:61180ms step_avg:43.70ms -step:1600/500000 train_loss:2.0960 train_time:69912ms step_avg:43.70ms -step:1800/500000 train_loss:2.1998 train_time:78649ms step_avg:43.69ms -step:2000/500000 train_loss:2.2506 train_time:87392ms step_avg:43.70ms -step:2200/500000 train_loss:2.0768 train_time:96125ms step_avg:43.69ms -step:2400/500000 train_loss:2.2012 train_time:104860ms step_avg:43.69ms -step:2600/500000 train_loss:2.4155 train_time:113597ms step_avg:43.69ms -step:2800/500000 train_loss:2.2353 train_time:122349ms step_avg:43.70ms -step:3000/500000 train_loss:2.2275 train_time:131087ms step_avg:43.70ms -step:3200/500000 train_loss:2.1879 train_time:139825ms step_avg:43.70ms -step:3400/500000 train_loss:2.1611 train_time:148561ms step_avg:43.69ms -step:3600/500000 train_loss:2.1179 train_time:157297ms step_avg:43.69ms -step:3800/500000 train_loss:2.2211 train_time:166018ms step_avg:43.69ms -step:4000/500000 train_loss:2.1627 train_time:174758ms step_avg:43.69ms -step:4200/500000 train_loss:2.1757 train_time:183623ms step_avg:43.72ms -step:4400/500000 train_loss:2.1143 train_time:192357ms step_avg:43.72ms -step:4600/500000 train_loss:1.9730 train_time:201085ms step_avg:43.71ms -step:4800/500000 train_loss:2.2618 train_time:209818ms step_avg:43.71ms -step:5000/500000 train_loss:2.0275 train_time:218548ms step_avg:43.71ms -step:5200/500000 train_loss:2.1723 train_time:227276ms step_avg:43.71ms -step:5400/500000 train_loss:2.1842 train_time:235999ms step_avg:43.70ms -step:5600/500000 train_loss:2.1841 train_time:244724ms step_avg:43.70ms -step:5800/500000 train_loss:2.1435 train_time:253451ms step_avg:43.70ms -step:6000/500000 train_loss:2.2229 train_time:262182ms step_avg:43.70ms -step:6200/500000 train_loss:2.0860 train_time:270912ms step_avg:43.70ms -step:6400/500000 train_loss:2.1598 train_time:279652ms step_avg:43.70ms -step:6600/500000 train_loss:2.1228 train_time:288379ms step_avg:43.69ms -step:6800/500000 train_loss:2.1890 train_time:297103ms step_avg:43.69ms -step:7000/500000 train_loss:2.2279 train_time:305832ms step_avg:43.69ms -step:7200/500000 train_loss:2.1999 train_time:314562ms step_avg:43.69ms -step:7400/500000 train_loss:2.1173 train_time:323296ms step_avg:43.69ms -step:7600/500000 train_loss:1.9995 train_time:332032ms step_avg:43.69ms -step:7800/500000 train_loss:2.1500 train_time:340763ms step_avg:43.69ms -step:8000/500000 train_loss:2.1150 train_time:349492ms step_avg:43.69ms -step:8200/500000 train_loss:2.1884 train_time:358222ms step_avg:43.69ms -step:8400/500000 train_loss:2.1346 train_time:367069ms step_avg:43.70ms -step:8600/500000 train_loss:2.1364 train_time:375806ms step_avg:43.70ms -step:8800/500000 train_loss:2.0998 train_time:384530ms step_avg:43.70ms -step:9000/500000 train_loss:2.0268 train_time:393259ms step_avg:43.70ms -step:9200/500000 train_loss:2.0837 train_time:401988ms step_avg:43.69ms -step:9400/500000 train_loss:2.1327 train_time:410718ms step_avg:43.69ms -step:9600/500000 train_loss:2.1452 train_time:419449ms step_avg:43.69ms -step:9800/500000 train_loss:2.0746 train_time:428188ms step_avg:43.69ms -step:10000/500000 train_loss:2.1114 train_time:436922ms step_avg:43.69ms -step:10200/500000 train_loss:2.0700 train_time:445656ms step_avg:43.69ms -step:10400/500000 train_loss:2.0994 train_time:454390ms step_avg:43.69ms -step:10600/500000 train_loss:1.9771 train_time:463127ms step_avg:43.69ms -step:10800/500000 train_loss:2.1884 train_time:471870ms step_avg:43.69ms -step:11000/500000 train_loss:2.1136 train_time:480607ms step_avg:43.69ms -step:11200/500000 train_loss:2.0714 train_time:489337ms step_avg:43.69ms -step:11400/500000 train_loss:2.0570 train_time:498065ms step_avg:43.69ms -step:11600/500000 train_loss:2.0643 train_time:506798ms step_avg:43.69ms -step:11800/500000 train_loss:2.0966 train_time:515533ms step_avg:43.69ms -step:12000/500000 train_loss:2.0734 train_time:524265ms step_avg:43.69ms -step:12200/500000 train_loss:2.2214 train_time:532995ms step_avg:43.69ms -step:12400/500000 train_loss:1.8632 train_time:541847ms step_avg:43.70ms -step:12600/500000 train_loss:2.0944 train_time:550591ms step_avg:43.70ms -step:12800/500000 train_loss:2.1093 train_time:559322ms step_avg:43.70ms -step:13000/500000 train_loss:2.1906 train_time:568053ms step_avg:43.70ms -step:13200/500000 train_loss:2.2034 train_time:576786ms step_avg:43.70ms -step:13400/500000 train_loss:2.0818 train_time:585518ms step_avg:43.70ms -step:13600/500000 train_loss:1.9545 train_time:594253ms step_avg:43.70ms -step:13800/500000 train_loss:2.0350 train_time:602978ms step_avg:43.69ms -step:14000/500000 train_loss:2.0960 train_time:611714ms step_avg:43.69ms -step:14200/500000 train_loss:2.1828 train_time:620450ms step_avg:43.69ms -step:14400/500000 train_loss:2.0819 train_time:629189ms step_avg:43.69ms -step:14600/500000 train_loss:2.1389 train_time:637923ms step_avg:43.69ms -step:14800/500000 train_loss:1.9175 train_time:646657ms step_avg:43.69ms -step:15000/500000 train_loss:2.0332 train_time:655403ms step_avg:43.69ms -step:15200/500000 train_loss:2.1406 train_time:664139ms step_avg:43.69ms -step:15400/500000 train_loss:2.0916 train_time:672882ms step_avg:43.69ms -step:15600/500000 train_loss:2.1547 train_time:681620ms step_avg:43.69ms -step:15800/500000 train_loss:1.8204 train_time:690352ms step_avg:43.69ms -step:16000/500000 train_loss:2.1014 train_time:699084ms step_avg:43.69ms -step:16200/500000 train_loss:2.1821 train_time:707820ms step_avg:43.69ms -step:16400/500000 train_loss:2.1894 train_time:716552ms step_avg:43.69ms -step:16600/500000 train_loss:2.0465 train_time:725412ms step_avg:43.70ms -step:16800/500000 train_loss:1.9392 train_time:734145ms step_avg:43.70ms -step:17000/500000 train_loss:1.9603 train_time:742889ms step_avg:43.70ms -step:17200/500000 train_loss:2.0424 train_time:751641ms step_avg:43.70ms -step:17400/500000 train_loss:2.0173 train_time:760388ms step_avg:43.70ms -step:17600/500000 train_loss:2.0272 train_time:769119ms step_avg:43.70ms -step:17800/500000 train_loss:2.0577 train_time:777863ms step_avg:43.70ms -step:18000/500000 train_loss:2.1233 train_time:786590ms step_avg:43.70ms -step:18200/500000 train_loss:2.2190 train_time:795310ms step_avg:43.70ms -step:18400/500000 train_loss:2.0528 train_time:804037ms step_avg:43.70ms -step:18600/500000 train_loss:2.0409 train_time:812766ms step_avg:43.70ms -step:18800/500000 train_loss:2.1861 train_time:821488ms step_avg:43.70ms -step:19000/500000 train_loss:2.0094 train_time:830217ms step_avg:43.70ms -step:19200/500000 train_loss:2.1793 train_time:838942ms step_avg:43.69ms -step:19400/500000 train_loss:2.1423 train_time:847676ms step_avg:43.69ms -step:19600/500000 train_loss:1.9673 train_time:856408ms step_avg:43.69ms -step:19800/500000 train_loss:2.2609 train_time:865135ms step_avg:43.69ms -step:20000/500000 train_loss:1.9800 train_time:873911ms step_avg:43.70ms -step:20000/500000 val_loss:2.0779 val_bpb:1.2307 train_time:873923ms step_avg:43.70ms -step:20200/500000 train_loss:2.0109 train_time:882650ms step_avg:43.70ms -step:20400/500000 train_loss:2.0714 train_time:891387ms step_avg:43.70ms -step:20600/500000 train_loss:2.0559 train_time:900265ms step_avg:43.70ms -step:20800/500000 train_loss:2.0586 train_time:909008ms step_avg:43.70ms -step:21000/500000 train_loss:2.0829 train_time:917731ms step_avg:43.70ms -step:21200/500000 train_loss:2.0840 train_time:926484ms step_avg:43.70ms -step:21400/500000 train_loss:2.0728 train_time:935227ms step_avg:43.70ms -step:21600/500000 train_loss:1.9781 train_time:943962ms step_avg:43.70ms -step:21800/500000 train_loss:2.1248 train_time:952707ms step_avg:43.70ms -step:22000/500000 train_loss:2.0181 train_time:961450ms step_avg:43.70ms -step:22200/500000 train_loss:1.9460 train_time:970185ms step_avg:43.70ms -step:22400/500000 train_loss:2.1037 train_time:978935ms step_avg:43.70ms -step:22600/500000 train_loss:2.0755 train_time:987675ms step_avg:43.70ms -step:22800/500000 train_loss:2.0061 train_time:996415ms step_avg:43.70ms -step:23000/500000 train_loss:2.1035 train_time:1005151ms step_avg:43.70ms -step:23200/500000 train_loss:2.0248 train_time:1013891ms step_avg:43.70ms -step:23400/500000 train_loss:2.0755 train_time:1022624ms step_avg:43.70ms -step:23600/500000 train_loss:2.0633 train_time:1031363ms step_avg:43.70ms -step:23800/500000 train_loss:2.0454 train_time:1040107ms step_avg:43.70ms -step:24000/500000 train_loss:2.0450 train_time:1048848ms step_avg:43.70ms -step:24200/500000 train_loss:2.1111 train_time:1057577ms step_avg:43.70ms -step:24400/500000 train_loss:2.2038 train_time:1066306ms step_avg:43.70ms -step:24600/500000 train_loss:2.1040 train_time:1075040ms step_avg:43.70ms -step:24800/500000 train_loss:2.0724 train_time:1083908ms step_avg:43.71ms -step:25000/500000 train_loss:2.0719 train_time:1092641ms step_avg:43.71ms -step:25200/500000 train_loss:2.0243 train_time:1101373ms step_avg:43.71ms -step:25400/500000 train_loss:1.9389 train_time:1110103ms step_avg:43.70ms -step:25600/500000 train_loss:1.9308 train_time:1118832ms step_avg:43.70ms -step:25800/500000 train_loss:2.0562 train_time:1127565ms step_avg:43.70ms -step:26000/500000 train_loss:2.1770 train_time:1136291ms step_avg:43.70ms -step:26200/500000 train_loss:2.0391 train_time:1145021ms step_avg:43.70ms -step:26400/500000 train_loss:1.9735 train_time:1153756ms step_avg:43.70ms -step:26600/500000 train_loss:2.0604 train_time:1162477ms step_avg:43.70ms -step:26800/500000 train_loss:1.9815 train_time:1171216ms step_avg:43.70ms -step:27000/500000 train_loss:2.0145 train_time:1179957ms step_avg:43.70ms -step:27200/500000 train_loss:2.0304 train_time:1188694ms step_avg:43.70ms -step:27400/500000 train_loss:2.0649 train_time:1197422ms step_avg:43.70ms -step:27600/500000 train_loss:2.0879 train_time:1206157ms step_avg:43.70ms -step:27800/500000 train_loss:2.2113 train_time:1214902ms step_avg:43.70ms -step:28000/500000 train_loss:2.1648 train_time:1223631ms step_avg:43.70ms -step:28200/500000 train_loss:2.0987 train_time:1232368ms step_avg:43.70ms -step:28400/500000 train_loss:2.1582 train_time:1241115ms step_avg:43.70ms -step:28600/500000 train_loss:2.0779 train_time:1249862ms step_avg:43.70ms -step:28800/500000 train_loss:2.0425 train_time:1258604ms step_avg:43.70ms -step:29000/500000 train_loss:2.1468 train_time:1267458ms step_avg:43.71ms -step:29200/500000 train_loss:2.1703 train_time:1276186ms step_avg:43.70ms -step:29400/500000 train_loss:2.0395 train_time:1284915ms step_avg:43.70ms -step:29600/500000 train_loss:2.0624 train_time:1293655ms step_avg:43.70ms -step:29800/500000 train_loss:2.0755 train_time:1302387ms step_avg:43.70ms -step:30000/500000 train_loss:2.0016 train_time:1311121ms step_avg:43.70ms -step:30200/500000 train_loss:2.0520 train_time:1319857ms step_avg:43.70ms -step:30400/500000 train_loss:1.9604 train_time:1328595ms step_avg:43.70ms -step:30600/500000 train_loss:2.0195 train_time:1337334ms step_avg:43.70ms -step:30800/500000 train_loss:2.0386 train_time:1346076ms step_avg:43.70ms -step:31000/500000 train_loss:2.1222 train_time:1354809ms step_avg:43.70ms -step:31200/500000 train_loss:2.1388 train_time:1363547ms step_avg:43.70ms -step:31400/500000 train_loss:2.1097 train_time:1372279ms step_avg:43.70ms -step:31600/500000 train_loss:2.0476 train_time:1381026ms step_avg:43.70ms -step:31800/500000 train_loss:2.0700 train_time:1389780ms step_avg:43.70ms -step:32000/500000 train_loss:2.1175 train_time:1398523ms step_avg:43.70ms -step:32200/500000 train_loss:2.0473 train_time:1407258ms step_avg:43.70ms -step:32400/500000 train_loss:1.9902 train_time:1416002ms step_avg:43.70ms -step:32600/500000 train_loss:2.0894 train_time:1424736ms step_avg:43.70ms -step:32800/500000 train_loss:2.1487 train_time:1433470ms step_avg:43.70ms -step:33000/500000 train_loss:2.0049 train_time:1442332ms step_avg:43.71ms -step:33200/500000 train_loss:2.0618 train_time:1451061ms step_avg:43.71ms -step:33400/500000 train_loss:1.9522 train_time:1459784ms step_avg:43.71ms -step:33600/500000 train_loss:2.1620 train_time:1468519ms step_avg:43.71ms -step:33800/500000 train_loss:2.0664 train_time:1477245ms step_avg:43.71ms -step:34000/500000 train_loss:2.3050 train_time:1485969ms step_avg:43.70ms -step:34200/500000 train_loss:2.1240 train_time:1494699ms step_avg:43.70ms -step:34400/500000 train_loss:2.0971 train_time:1503428ms step_avg:43.70ms -step:34600/500000 train_loss:2.1775 train_time:1512165ms step_avg:43.70ms -step:34800/500000 train_loss:2.0770 train_time:1520903ms step_avg:43.70ms -step:35000/500000 train_loss:2.0156 train_time:1529639ms step_avg:43.70ms -step:35200/500000 train_loss:2.0252 train_time:1538380ms step_avg:43.70ms -step:35400/500000 train_loss:2.0857 train_time:1547117ms step_avg:43.70ms -step:35600/500000 train_loss:2.2012 train_time:1555857ms step_avg:43.70ms -step:35800/500000 train_loss:2.0054 train_time:1564595ms step_avg:43.70ms -step:36000/500000 train_loss:2.1126 train_time:1573347ms step_avg:43.70ms -step:36200/500000 train_loss:1.9196 train_time:1582070ms step_avg:43.70ms -step:36400/500000 train_loss:2.0816 train_time:1590801ms step_avg:43.70ms -step:36600/500000 train_loss:2.0577 train_time:1599527ms step_avg:43.70ms -step:36800/500000 train_loss:1.9859 train_time:1608254ms step_avg:43.70ms -step:37000/500000 train_loss:1.9921 train_time:1616984ms step_avg:43.70ms -step:37200/500000 train_loss:2.0478 train_time:1625821ms step_avg:43.70ms -step:37400/500000 train_loss:2.0508 train_time:1634571ms step_avg:43.71ms -step:37600/500000 train_loss:2.0218 train_time:1643314ms step_avg:43.71ms -step:37800/500000 train_loss:2.1115 train_time:1652053ms step_avg:43.71ms -step:38000/500000 train_loss:1.9950 train_time:1660788ms step_avg:43.70ms -step:38200/500000 train_loss:2.1599 train_time:1669527ms step_avg:43.70ms -step:38400/500000 train_loss:2.0854 train_time:1678274ms step_avg:43.71ms -step:38600/500000 train_loss:2.0777 train_time:1687012ms step_avg:43.70ms -step:38800/500000 train_loss:2.0381 train_time:1695754ms step_avg:43.71ms -step:39000/500000 train_loss:1.9603 train_time:1704494ms step_avg:43.70ms -step:39200/500000 train_loss:2.0654 train_time:1713230ms step_avg:43.70ms -step:39400/500000 train_loss:2.0960 train_time:1721966ms step_avg:43.70ms -step:39600/500000 train_loss:2.1160 train_time:1730705ms step_avg:43.70ms -step:39800/500000 train_loss:2.0706 train_time:1739449ms step_avg:43.70ms -step:40000/500000 train_loss:1.9931 train_time:1748179ms step_avg:43.70ms -step:40000/500000 val_loss:2.0513 val_bpb:1.2149 train_time:1748195ms step_avg:43.70ms -step:40200/500000 train_loss:2.0616 train_time:1757025ms step_avg:43.71ms -step:40400/500000 train_loss:2.1068 train_time:1765759ms step_avg:43.71ms -step:40600/500000 train_loss:2.1224 train_time:1774490ms step_avg:43.71ms -step:40800/500000 train_loss:1.9901 train_time:1783221ms step_avg:43.71ms -step:41000/500000 train_loss:2.0482 train_time:1791956ms step_avg:43.71ms -step:41200/500000 train_loss:2.0054 train_time:1800685ms step_avg:43.71ms -step:41400/500000 train_loss:2.1389 train_time:1809420ms step_avg:43.71ms -step:41600/500000 train_loss:1.9084 train_time:1818157ms step_avg:43.71ms -step:41800/500000 train_loss:1.9776 train_time:1826898ms step_avg:43.71ms -step:42000/500000 train_loss:2.1240 train_time:1835629ms step_avg:43.71ms -step:42200/500000 train_loss:2.0346 train_time:1844369ms step_avg:43.71ms -step:42400/500000 train_loss:1.8743 train_time:1853109ms step_avg:43.71ms -step:42600/500000 train_loss:2.0811 train_time:1861841ms step_avg:43.71ms -step:42800/500000 train_loss:1.9275 train_time:1870573ms step_avg:43.70ms -step:43000/500000 train_loss:2.0186 train_time:1879304ms step_avg:43.70ms -step:43200/500000 train_loss:1.9877 train_time:1888040ms step_avg:43.70ms -step:43400/500000 train_loss:2.0362 train_time:1896771ms step_avg:43.70ms -step:43600/500000 train_loss:2.0231 train_time:1905503ms step_avg:43.70ms -step:43800/500000 train_loss:2.0801 train_time:1914237ms step_avg:43.70ms -step:44000/500000 train_loss:2.0663 train_time:1922975ms step_avg:43.70ms -step:44200/500000 train_loss:2.0227 train_time:1931827ms step_avg:43.71ms -step:44400/500000 train_loss:1.9966 train_time:1940564ms step_avg:43.71ms -step:44600/500000 train_loss:2.0269 train_time:1949298ms step_avg:43.71ms -step:44800/500000 train_loss:2.4033 train_time:1958041ms step_avg:43.71ms -step:45000/500000 train_loss:2.1128 train_time:1966777ms step_avg:43.71ms -step:45200/500000 train_loss:2.1673 train_time:1975520ms step_avg:43.71ms -step:45400/500000 train_loss:2.1298 train_time:1984260ms step_avg:43.71ms -step:45600/500000 train_loss:1.8693 train_time:1992998ms step_avg:43.71ms -step:45800/500000 train_loss:2.0838 train_time:2001733ms step_avg:43.71ms -step:46000/500000 train_loss:2.0543 train_time:2010471ms step_avg:43.71ms -step:46200/500000 train_loss:2.1408 train_time:2019211ms step_avg:43.71ms -step:46400/500000 train_loss:2.0154 train_time:2027949ms step_avg:43.71ms -step:46600/500000 train_loss:2.0614 train_time:2036684ms step_avg:43.71ms -step:46800/500000 train_loss:2.1562 train_time:2045414ms step_avg:43.71ms -step:47000/500000 train_loss:2.0508 train_time:2054144ms step_avg:43.71ms -step:47200/500000 train_loss:2.0838 train_time:2062882ms step_avg:43.71ms -step:47400/500000 train_loss:2.0732 train_time:2071643ms step_avg:43.71ms -step:47600/500000 train_loss:2.0627 train_time:2080393ms step_avg:43.71ms -step:47800/500000 train_loss:2.0581 train_time:2089138ms step_avg:43.71ms -step:48000/500000 train_loss:2.1821 train_time:2097874ms step_avg:43.71ms -step:48200/500000 train_loss:2.0664 train_time:2106621ms step_avg:43.71ms -step:48400/500000 train_loss:2.0724 train_time:2115483ms step_avg:43.71ms -step:48600/500000 train_loss:1.9536 train_time:2124224ms step_avg:43.71ms -step:48800/500000 train_loss:2.0122 train_time:2132956ms step_avg:43.71ms -step:49000/500000 train_loss:2.1620 train_time:2141694ms step_avg:43.71ms -step:49200/500000 train_loss:1.9946 train_time:2150431ms step_avg:43.71ms -step:49400/500000 train_loss:2.0624 train_time:2159169ms step_avg:43.71ms -step:49600/500000 train_loss:2.0001 train_time:2167906ms step_avg:43.71ms -step:49800/500000 train_loss:2.0727 train_time:2176641ms step_avg:43.71ms -step:50000/500000 train_loss:2.0876 train_time:2185379ms step_avg:43.71ms -step:50200/500000 train_loss:2.0509 train_time:2194118ms step_avg:43.71ms -step:50400/500000 train_loss:2.1277 train_time:2202852ms step_avg:43.71ms -step:50600/500000 train_loss:2.0534 train_time:2211582ms step_avg:43.71ms -step:50800/500000 train_loss:1.9725 train_time:2220324ms step_avg:43.71ms -step:51000/500000 train_loss:1.9960 train_time:2229067ms step_avg:43.71ms -step:51200/500000 train_loss:2.0345 train_time:2237807ms step_avg:43.71ms -step:51400/500000 train_loss:2.0990 train_time:2246550ms step_avg:43.71ms -step:51600/500000 train_loss:2.0457 train_time:2255285ms step_avg:43.71ms -step:51800/500000 train_loss:2.0500 train_time:2264020ms step_avg:43.71ms -step:52000/500000 train_loss:1.9877 train_time:2272760ms step_avg:43.71ms -step:52200/500000 train_loss:1.9642 train_time:2281497ms step_avg:43.71ms -step:52400/500000 train_loss:2.0589 train_time:2290236ms step_avg:43.71ms -step:52600/500000 train_loss:2.0642 train_time:2299090ms step_avg:43.71ms -step:52800/500000 train_loss:2.0002 train_time:2307829ms step_avg:43.71ms -step:53000/500000 train_loss:2.1251 train_time:2316559ms step_avg:43.71ms -step:53200/500000 train_loss:2.2760 train_time:2325293ms step_avg:43.71ms -step:53400/500000 train_loss:1.9892 train_time:2334032ms step_avg:43.71ms -step:53600/500000 train_loss:2.0399 train_time:2342765ms step_avg:43.71ms -step:53800/500000 train_loss:2.0171 train_time:2351497ms step_avg:43.71ms -step:54000/500000 train_loss:1.9835 train_time:2360222ms step_avg:43.71ms -step:54200/500000 train_loss:1.9996 train_time:2368976ms step_avg:43.71ms -step:54400/500000 train_loss:2.1281 train_time:2377711ms step_avg:43.71ms -step:54600/500000 train_loss:2.0236 train_time:2386444ms step_avg:43.71ms -step:54800/500000 train_loss:1.9770 train_time:2395176ms step_avg:43.71ms -step:55000/500000 train_loss:2.1209 train_time:2403912ms step_avg:43.71ms -step:55200/500000 train_loss:2.0261 train_time:2412637ms step_avg:43.71ms -step:55400/500000 train_loss:2.0456 train_time:2421371ms step_avg:43.71ms -step:55600/500000 train_loss:2.0821 train_time:2430109ms step_avg:43.71ms -step:55800/500000 train_loss:1.9503 train_time:2438834ms step_avg:43.71ms -step:56000/500000 train_loss:2.0074 train_time:2447565ms step_avg:43.71ms -step:56200/500000 train_loss:2.0663 train_time:2456301ms step_avg:43.71ms -step:56400/500000 train_loss:2.0383 train_time:2465029ms step_avg:43.71ms -step:56600/500000 train_loss:2.0852 train_time:2473880ms step_avg:43.71ms -step:56800/500000 train_loss:1.9910 train_time:2482615ms step_avg:43.71ms -step:57000/500000 train_loss:2.0997 train_time:2491342ms step_avg:43.71ms -step:57200/500000 train_loss:2.0564 train_time:2500070ms step_avg:43.71ms -step:57400/500000 train_loss:1.9974 train_time:2508803ms step_avg:43.71ms -step:57600/500000 train_loss:2.0217 train_time:2517529ms step_avg:43.71ms -step:57800/500000 train_loss:1.9696 train_time:2526265ms step_avg:43.71ms -step:58000/500000 train_loss:2.0311 train_time:2534998ms step_avg:43.71ms -step:58200/500000 train_loss:2.0664 train_time:2543727ms step_avg:43.71ms -step:58400/500000 train_loss:2.0663 train_time:2552468ms step_avg:43.71ms -step:58600/500000 train_loss:1.8755 train_time:2561203ms step_avg:43.71ms -step:58800/500000 train_loss:2.0705 train_time:2569943ms step_avg:43.71ms -step:59000/500000 train_loss:2.1983 train_time:2578691ms step_avg:43.71ms -step:59200/500000 train_loss:2.0254 train_time:2587421ms step_avg:43.71ms -step:59400/500000 train_loss:1.9081 train_time:2596157ms step_avg:43.71ms -step:59600/500000 train_loss:1.8892 train_time:2604900ms step_avg:43.71ms -step:59800/500000 train_loss:2.0723 train_time:2613631ms step_avg:43.71ms -step:60000/500000 train_loss:1.9396 train_time:2622368ms step_avg:43.71ms -step:60000/500000 val_loss:2.0377 val_bpb:1.2069 train_time:2622385ms step_avg:43.71ms -step:60200/500000 train_loss:2.0829 train_time:2631107ms step_avg:43.71ms -step:60400/500000 train_loss:2.0247 train_time:2639847ms step_avg:43.71ms -step:60600/500000 train_loss:2.0155 train_time:2648581ms step_avg:43.71ms -step:60800/500000 train_loss:1.9590 train_time:2657437ms step_avg:43.71ms -step:61000/500000 train_loss:2.0764 train_time:2666173ms step_avg:43.71ms -step:61200/500000 train_loss:2.0025 train_time:2674908ms step_avg:43.71ms -step:61400/500000 train_loss:2.0852 train_time:2683660ms step_avg:43.71ms -step:61600/500000 train_loss:2.1964 train_time:2692401ms step_avg:43.71ms -step:61800/500000 train_loss:2.1862 train_time:2701152ms step_avg:43.71ms -step:62000/500000 train_loss:2.0954 train_time:2709889ms step_avg:43.71ms -step:62200/500000 train_loss:2.0075 train_time:2718626ms step_avg:43.71ms -step:62400/500000 train_loss:1.9852 train_time:2727373ms step_avg:43.71ms -step:62600/500000 train_loss:2.0035 train_time:2736123ms step_avg:43.71ms -step:62800/500000 train_loss:2.1441 train_time:2744854ms step_avg:43.71ms -step:63000/500000 train_loss:1.8316 train_time:2753594ms step_avg:43.71ms -step:63200/500000 train_loss:2.0102 train_time:2762341ms step_avg:43.71ms -step:63400/500000 train_loss:1.8111 train_time:2771087ms step_avg:43.71ms -step:63600/500000 train_loss:2.0566 train_time:2779840ms step_avg:43.71ms -step:63800/500000 train_loss:2.0424 train_time:2788589ms step_avg:43.71ms -step:64000/500000 train_loss:1.9626 train_time:2797329ms step_avg:43.71ms -step:64200/500000 train_loss:2.0900 train_time:2806076ms step_avg:43.71ms -step:64400/500000 train_loss:2.0560 train_time:2814809ms step_avg:43.71ms -step:64600/500000 train_loss:1.9803 train_time:2823559ms step_avg:43.71ms -step:64800/500000 train_loss:2.0888 train_time:2832415ms step_avg:43.71ms -step:65000/500000 train_loss:1.9832 train_time:2841158ms step_avg:43.71ms -step:65200/500000 train_loss:2.0433 train_time:2849898ms step_avg:43.71ms -step:65400/500000 train_loss:1.8724 train_time:2858640ms step_avg:43.71ms -step:65600/500000 train_loss:1.9251 train_time:2867381ms step_avg:43.71ms -step:65800/500000 train_loss:2.0526 train_time:2876123ms step_avg:43.71ms -step:66000/500000 train_loss:1.9819 train_time:2884855ms step_avg:43.71ms -step:66200/500000 train_loss:2.1081 train_time:2893587ms step_avg:43.71ms -step:66400/500000 train_loss:2.0156 train_time:2902325ms step_avg:43.71ms -step:66600/500000 train_loss:2.0231 train_time:2911066ms step_avg:43.71ms -step:66800/500000 train_loss:1.9735 train_time:2919799ms step_avg:43.71ms -step:67000/500000 train_loss:1.9554 train_time:2928542ms step_avg:43.71ms -step:67200/500000 train_loss:2.0098 train_time:2937282ms step_avg:43.71ms -step:67400/500000 train_loss:2.1122 train_time:2946019ms step_avg:43.71ms -step:67600/500000 train_loss:2.0330 train_time:2954754ms step_avg:43.71ms -step:67800/500000 train_loss:1.9299 train_time:2963497ms step_avg:43.71ms -step:68000/500000 train_loss:2.0881 train_time:2972236ms step_avg:43.71ms -step:68200/500000 train_loss:1.9947 train_time:2980982ms step_avg:43.71ms -step:68400/500000 train_loss:2.1097 train_time:2989720ms step_avg:43.71ms -step:68600/500000 train_loss:2.1249 train_time:2998464ms step_avg:43.71ms -step:68800/500000 train_loss:1.9755 train_time:3007204ms step_avg:43.71ms -step:69000/500000 train_loss:1.9303 train_time:3016067ms step_avg:43.71ms -step:69200/500000 train_loss:2.0941 train_time:3024806ms step_avg:43.71ms -step:69400/500000 train_loss:2.0425 train_time:3033540ms step_avg:43.71ms -step:69600/500000 train_loss:1.9208 train_time:3042275ms step_avg:43.71ms -step:69800/500000 train_loss:2.0736 train_time:3051002ms step_avg:43.71ms -step:70000/500000 train_loss:2.0506 train_time:3059740ms step_avg:43.71ms -step:70200/500000 train_loss:1.9177 train_time:3068480ms step_avg:43.71ms -step:70400/500000 train_loss:2.0552 train_time:3077209ms step_avg:43.71ms -step:70600/500000 train_loss:2.1593 train_time:3085950ms step_avg:43.71ms -step:70800/500000 train_loss:2.3957 train_time:3094684ms step_avg:43.71ms -step:71000/500000 train_loss:2.1222 train_time:3103419ms step_avg:43.71ms -step:71200/500000 train_loss:2.1032 train_time:3112149ms step_avg:43.71ms -step:71400/500000 train_loss:1.9591 train_time:3120894ms step_avg:43.71ms -step:71600/500000 train_loss:1.9592 train_time:3129636ms step_avg:43.71ms -step:71800/500000 train_loss:2.0074 train_time:3138374ms step_avg:43.71ms -step:72000/500000 train_loss:2.2693 train_time:3147117ms step_avg:43.71ms -step:72200/500000 train_loss:2.1045 train_time:3155857ms step_avg:43.71ms -step:72400/500000 train_loss:1.8356 train_time:3164598ms step_avg:43.71ms -step:72600/500000 train_loss:2.0390 train_time:3173344ms step_avg:43.71ms -step:72800/500000 train_loss:1.9867 train_time:3182090ms step_avg:43.71ms -step:73000/500000 train_loss:2.0488 train_time:3190957ms step_avg:43.71ms -step:73200/500000 train_loss:2.0355 train_time:3199692ms step_avg:43.71ms -step:73400/500000 train_loss:1.9839 train_time:3208438ms step_avg:43.71ms -step:73600/500000 train_loss:2.0317 train_time:3217184ms step_avg:43.71ms -step:73800/500000 train_loss:1.9925 train_time:3225921ms step_avg:43.71ms -step:74000/500000 train_loss:1.9522 train_time:3234658ms step_avg:43.71ms -step:74200/500000 train_loss:2.1049 train_time:3243364ms step_avg:43.71ms -step:74400/500000 train_loss:1.9785 train_time:3252099ms step_avg:43.71ms -step:74600/500000 train_loss:2.0592 train_time:3260837ms step_avg:43.71ms -step:74800/500000 train_loss:2.0489 train_time:3269587ms step_avg:43.71ms -step:75000/500000 train_loss:2.0631 train_time:3278321ms step_avg:43.71ms -step:75200/500000 train_loss:2.0235 train_time:3287054ms step_avg:43.71ms -step:75400/500000 train_loss:1.8544 train_time:3295787ms step_avg:43.71ms -step:75600/500000 train_loss:1.8574 train_time:3304514ms step_avg:43.71ms -step:75800/500000 train_loss:2.0504 train_time:3313237ms step_avg:43.71ms -step:76000/500000 train_loss:1.9967 train_time:3322083ms step_avg:43.71ms -step:76200/500000 train_loss:2.0440 train_time:3330827ms step_avg:43.71ms -step:76400/500000 train_loss:1.9953 train_time:3339574ms step_avg:43.71ms -step:76600/500000 train_loss:2.0072 train_time:3348313ms step_avg:43.71ms -step:76800/500000 train_loss:2.0488 train_time:3357061ms step_avg:43.71ms -step:77000/500000 train_loss:2.0404 train_time:3365811ms step_avg:43.71ms -step:77200/500000 train_loss:2.0290 train_time:3374553ms step_avg:43.71ms -step:77400/500000 train_loss:1.8698 train_time:3383295ms step_avg:43.71ms -step:77600/500000 train_loss:1.9732 train_time:3392031ms step_avg:43.71ms -step:77800/500000 train_loss:2.0140 train_time:3400774ms step_avg:43.71ms -step:78000/500000 train_loss:1.9584 train_time:3409520ms step_avg:43.71ms -step:78200/500000 train_loss:2.2657 train_time:3418280ms step_avg:43.71ms -step:78400/500000 train_loss:2.0041 train_time:3427037ms step_avg:43.71ms -step:78600/500000 train_loss:2.0953 train_time:3435793ms step_avg:43.71ms -step:78800/500000 train_loss:2.0426 train_time:3444542ms step_avg:43.71ms -step:79000/500000 train_loss:2.0681 train_time:3453286ms step_avg:43.71ms -step:79200/500000 train_loss:2.0154 train_time:3462037ms step_avg:43.71ms -step:79400/500000 train_loss:1.9769 train_time:3470781ms step_avg:43.71ms -step:79600/500000 train_loss:2.0938 train_time:3479530ms step_avg:43.71ms -step:79800/500000 train_loss:2.0043 train_time:3488267ms step_avg:43.71ms -step:80000/500000 train_loss:2.0123 train_time:3497014ms step_avg:43.71ms -step:80000/500000 val_loss:2.0306 val_bpb:1.2026 train_time:3497030ms step_avg:43.71ms -step:80200/500000 train_loss:1.9839 train_time:3505891ms step_avg:43.71ms -step:80400/500000 train_loss:2.0201 train_time:3514643ms step_avg:43.71ms -step:80600/500000 train_loss:1.9697 train_time:3523393ms step_avg:43.71ms -step:80800/500000 train_loss:1.8849 train_time:3532145ms step_avg:43.71ms -step:81000/500000 train_loss:2.0080 train_time:3540891ms step_avg:43.71ms -step:81200/500000 train_loss:2.0728 train_time:3549647ms step_avg:43.71ms -step:81400/500000 train_loss:2.1014 train_time:3558381ms step_avg:43.71ms -step:81600/500000 train_loss:2.0568 train_time:3567112ms step_avg:43.71ms -step:81800/500000 train_loss:2.0455 train_time:3575852ms step_avg:43.71ms -step:82000/500000 train_loss:2.0547 train_time:3584589ms step_avg:43.71ms -step:82200/500000 train_loss:2.0608 train_time:3593331ms step_avg:43.71ms -step:82400/500000 train_loss:2.3063 train_time:3602075ms step_avg:43.71ms -step:82600/500000 train_loss:2.0339 train_time:3610813ms step_avg:43.71ms -step:82800/500000 train_loss:2.1101 train_time:3619548ms step_avg:43.71ms -step:83000/500000 train_loss:1.9921 train_time:3628292ms step_avg:43.71ms -step:83200/500000 train_loss:1.9948 train_time:3637032ms step_avg:43.71ms -step:83400/500000 train_loss:2.0244 train_time:3645763ms step_avg:43.71ms -step:83600/500000 train_loss:1.9101 train_time:3654495ms step_avg:43.71ms -step:83800/500000 train_loss:2.0107 train_time:3663231ms step_avg:43.71ms -step:84000/500000 train_loss:1.9378 train_time:3671965ms step_avg:43.71ms -step:84200/500000 train_loss:2.0769 train_time:3680704ms step_avg:43.71ms -step:84400/500000 train_loss:1.9761 train_time:3689562ms step_avg:43.72ms -step:84600/500000 train_loss:2.0646 train_time:3698295ms step_avg:43.72ms -step:84800/500000 train_loss:2.0266 train_time:3707034ms step_avg:43.72ms -step:85000/500000 train_loss:2.0077 train_time:3715778ms step_avg:43.72ms -step:85200/500000 train_loss:2.1978 train_time:3724509ms step_avg:43.71ms -step:85400/500000 train_loss:2.1181 train_time:3733240ms step_avg:43.71ms -step:85600/500000 train_loss:1.9731 train_time:3741988ms step_avg:43.71ms -step:85800/500000 train_loss:1.9917 train_time:3750739ms step_avg:43.71ms -step:86000/500000 train_loss:1.9015 train_time:3759481ms step_avg:43.71ms -step:86200/500000 train_loss:1.9947 train_time:3768217ms step_avg:43.71ms -step:86400/500000 train_loss:2.0555 train_time:3776958ms step_avg:43.71ms -step:86600/500000 train_loss:1.9431 train_time:3785705ms step_avg:43.71ms -step:86800/500000 train_loss:1.9248 train_time:3794444ms step_avg:43.71ms -step:87000/500000 train_loss:2.2830 train_time:3803191ms step_avg:43.71ms -step:87200/500000 train_loss:1.9020 train_time:3811936ms step_avg:43.71ms -step:87400/500000 train_loss:2.0276 train_time:3820677ms step_avg:43.71ms -step:87600/500000 train_loss:1.9286 train_time:3829421ms step_avg:43.71ms -step:87800/500000 train_loss:2.0150 train_time:3838163ms step_avg:43.71ms -step:88000/500000 train_loss:2.1064 train_time:3846914ms step_avg:43.71ms -step:88200/500000 train_loss:1.8532 train_time:3855651ms step_avg:43.71ms -step:88400/500000 train_loss:2.1184 train_time:3864514ms step_avg:43.72ms -step:88600/500000 train_loss:1.9995 train_time:3873261ms step_avg:43.72ms -step:88800/500000 train_loss:2.0633 train_time:3882020ms step_avg:43.72ms -step:89000/500000 train_loss:2.0382 train_time:3890758ms step_avg:43.72ms -step:89200/500000 train_loss:1.9755 train_time:3899507ms step_avg:43.72ms -step:89400/500000 train_loss:1.9484 train_time:3908254ms step_avg:43.72ms -step:89600/500000 train_loss:2.0261 train_time:3916999ms step_avg:43.72ms -step:89800/500000 train_loss:1.9487 train_time:3925742ms step_avg:43.72ms -step:90000/500000 train_loss:2.0480 train_time:3934482ms step_avg:43.72ms -step:90200/500000 train_loss:2.1049 train_time:3943216ms step_avg:43.72ms -step:90400/500000 train_loss:2.0033 train_time:3951955ms step_avg:43.72ms -step:90600/500000 train_loss:2.1195 train_time:3960696ms step_avg:43.72ms -step:90800/500000 train_loss:2.0872 train_time:3969435ms step_avg:43.72ms -step:91000/500000 train_loss:2.0414 train_time:3978176ms step_avg:43.72ms -step:91200/500000 train_loss:1.9898 train_time:3986934ms step_avg:43.72ms -step:91400/500000 train_loss:2.1097 train_time:3995676ms step_avg:43.72ms -step:91600/500000 train_loss:2.0608 train_time:4004427ms step_avg:43.72ms -step:91800/500000 train_loss:2.0089 train_time:4013164ms step_avg:43.72ms -step:92000/500000 train_loss:2.0163 train_time:4021901ms step_avg:43.72ms -step:92200/500000 train_loss:2.0968 train_time:4030647ms step_avg:43.72ms -step:92400/500000 train_loss:1.9743 train_time:4039391ms step_avg:43.72ms -step:92600/500000 train_loss:2.0348 train_time:4048245ms step_avg:43.72ms -step:92800/500000 train_loss:2.0199 train_time:4056989ms step_avg:43.72ms -step:93000/500000 train_loss:2.0642 train_time:4065735ms step_avg:43.72ms -step:93200/500000 train_loss:2.1773 train_time:4074480ms step_avg:43.72ms -step:93400/500000 train_loss:2.0545 train_time:4083215ms step_avg:43.72ms -step:93600/500000 train_loss:1.9643 train_time:4091955ms step_avg:43.72ms -step:93800/500000 train_loss:2.0638 train_time:4100694ms step_avg:43.72ms -step:94000/500000 train_loss:2.0236 train_time:4109446ms step_avg:43.72ms -step:94200/500000 train_loss:2.0259 train_time:4118191ms step_avg:43.72ms -step:94400/500000 train_loss:1.9855 train_time:4126935ms step_avg:43.72ms -step:94600/500000 train_loss:2.0248 train_time:4135681ms step_avg:43.72ms -step:94800/500000 train_loss:2.0139 train_time:4144419ms step_avg:43.72ms -step:95000/500000 train_loss:2.1720 train_time:4153155ms step_avg:43.72ms -step:95200/500000 train_loss:2.0017 train_time:4161900ms step_avg:43.72ms -step:95400/500000 train_loss:2.0994 train_time:4170642ms step_avg:43.72ms -step:95600/500000 train_loss:2.0451 train_time:4179383ms step_avg:43.72ms -step:95800/500000 train_loss:1.9698 train_time:4188120ms step_avg:43.72ms -step:96000/500000 train_loss:2.0817 train_time:4196853ms step_avg:43.72ms -step:96200/500000 train_loss:2.1018 train_time:4205589ms step_avg:43.72ms -step:96400/500000 train_loss:2.1385 train_time:4214327ms step_avg:43.72ms -step:96600/500000 train_loss:2.0407 train_time:4223177ms step_avg:43.72ms -step:96800/500000 train_loss:1.9939 train_time:4231910ms step_avg:43.72ms -step:97000/500000 train_loss:2.0416 train_time:4240640ms step_avg:43.72ms -step:97200/500000 train_loss:2.0274 train_time:4249373ms step_avg:43.72ms -step:97400/500000 train_loss:1.8758 train_time:4258108ms step_avg:43.72ms -step:97600/500000 train_loss:1.9845 train_time:4266837ms step_avg:43.72ms -step:97800/500000 train_loss:2.1325 train_time:4275574ms step_avg:43.72ms -step:98000/500000 train_loss:1.9135 train_time:4284317ms step_avg:43.72ms -step:98200/500000 train_loss:2.1159 train_time:4293049ms step_avg:43.72ms -step:98400/500000 train_loss:1.9857 train_time:4301781ms step_avg:43.72ms -step:98600/500000 train_loss:1.9655 train_time:4310513ms step_avg:43.72ms -step:98800/500000 train_loss:1.9529 train_time:4319254ms step_avg:43.72ms -step:99000/500000 train_loss:2.0475 train_time:4327994ms step_avg:43.72ms -step:99200/500000 train_loss:2.0054 train_time:4336730ms step_avg:43.72ms -step:99400/500000 train_loss:1.9931 train_time:4345463ms step_avg:43.72ms -step:99600/500000 train_loss:1.9279 train_time:4354197ms step_avg:43.72ms -step:99800/500000 train_loss:1.9186 train_time:4362935ms step_avg:43.72ms -step:100000/500000 train_loss:1.7548 train_time:4371686ms step_avg:43.72ms -step:100000/500000 val_loss:2.0253 val_bpb:1.1995 train_time:4371696ms step_avg:43.72ms -step:100200/500000 train_loss:1.9653 train_time:4380428ms step_avg:43.72ms -step:100400/500000 train_loss:2.1490 train_time:4389172ms step_avg:43.72ms -step:100600/500000 train_loss:2.1338 train_time:4397923ms step_avg:43.72ms -step:100800/500000 train_loss:1.9365 train_time:4406793ms step_avg:43.72ms -step:101000/500000 train_loss:2.1717 train_time:4415535ms step_avg:43.72ms -step:101200/500000 train_loss:2.0397 train_time:4424281ms step_avg:43.72ms -step:101400/500000 train_loss:2.0684 train_time:4433025ms step_avg:43.72ms -step:101600/500000 train_loss:2.0624 train_time:4441766ms step_avg:43.72ms -step:101800/500000 train_loss:1.9881 train_time:4450520ms step_avg:43.72ms -step:102000/500000 train_loss:2.0370 train_time:4459258ms step_avg:43.72ms -step:102200/500000 train_loss:2.0265 train_time:4468004ms step_avg:43.72ms -step:102400/500000 train_loss:2.0059 train_time:4476745ms step_avg:43.72ms -step:102600/500000 train_loss:2.1766 train_time:4485487ms step_avg:43.72ms -step:102800/500000 train_loss:1.9531 train_time:4494235ms step_avg:43.72ms -step:103000/500000 train_loss:2.0770 train_time:4502984ms step_avg:43.72ms -step:103200/500000 train_loss:1.9543 train_time:4511735ms step_avg:43.72ms -step:103400/500000 train_loss:2.0411 train_time:4520487ms step_avg:43.72ms -step:103600/500000 train_loss:2.0500 train_time:4529227ms step_avg:43.72ms -step:103800/500000 train_loss:1.9965 train_time:4537967ms step_avg:43.72ms -step:104000/500000 train_loss:2.0624 train_time:4546714ms step_avg:43.72ms -step:104200/500000 train_loss:1.9853 train_time:4555461ms step_avg:43.72ms -step:104400/500000 train_loss:1.9747 train_time:4564204ms step_avg:43.72ms -step:104600/500000 train_loss:2.0848 train_time:4572941ms step_avg:43.72ms -step:104800/500000 train_loss:2.0619 train_time:4581686ms step_avg:43.72ms -step:105000/500000 train_loss:2.0278 train_time:4590551ms step_avg:43.72ms -step:105200/500000 train_loss:1.9722 train_time:4599291ms step_avg:43.72ms -step:105400/500000 train_loss:2.1234 train_time:4608037ms step_avg:43.72ms -step:105600/500000 train_loss:1.9276 train_time:4616777ms step_avg:43.72ms -step:105800/500000 train_loss:2.1207 train_time:4625512ms step_avg:43.72ms -step:106000/500000 train_loss:1.9724 train_time:4634245ms step_avg:43.72ms -step:106200/500000 train_loss:1.8480 train_time:4642984ms step_avg:43.72ms -step:106400/500000 train_loss:1.9962 train_time:4651724ms step_avg:43.72ms -step:106600/500000 train_loss:2.0682 train_time:4660463ms step_avg:43.72ms -step:106800/500000 train_loss:1.9596 train_time:4669202ms step_avg:43.72ms -step:107000/500000 train_loss:1.9887 train_time:4677945ms step_avg:43.72ms -step:107200/500000 train_loss:2.0878 train_time:4686687ms step_avg:43.72ms -step:107400/500000 train_loss:1.9820 train_time:4695424ms step_avg:43.72ms -step:107600/500000 train_loss:1.8612 train_time:4704155ms step_avg:43.72ms -step:107800/500000 train_loss:2.1396 train_time:4712892ms step_avg:43.72ms -step:108000/500000 train_loss:2.0850 train_time:4721621ms step_avg:43.72ms -step:108200/500000 train_loss:2.0261 train_time:4730365ms step_avg:43.72ms -step:108400/500000 train_loss:2.0620 train_time:4739111ms step_avg:43.72ms -step:108600/500000 train_loss:1.9826 train_time:4747853ms step_avg:43.72ms -step:108800/500000 train_loss:2.1798 train_time:4756598ms step_avg:43.72ms -step:109000/500000 train_loss:2.0688 train_time:4765478ms step_avg:43.72ms -step:109200/500000 train_loss:2.1944 train_time:4774228ms step_avg:43.72ms -step:109400/500000 train_loss:2.0273 train_time:4782964ms step_avg:43.72ms -step:109600/500000 train_loss:1.9818 train_time:4791698ms step_avg:43.72ms -step:109800/500000 train_loss:1.8197 train_time:4800433ms step_avg:43.72ms -step:110000/500000 train_loss:1.9632 train_time:4809167ms step_avg:43.72ms -step:110200/500000 train_loss:2.0929 train_time:4817911ms step_avg:43.72ms -step:110400/500000 train_loss:1.9924 train_time:4826648ms step_avg:43.72ms -step:110600/500000 train_loss:1.9691 train_time:4835381ms step_avg:43.72ms -step:110800/500000 train_loss:2.0186 train_time:4844110ms step_avg:43.72ms -step:111000/500000 train_loss:2.0252 train_time:4852847ms step_avg:43.72ms -step:111200/500000 train_loss:2.2204 train_time:4861590ms step_avg:43.72ms -step:111400/500000 train_loss:2.0328 train_time:4870299ms step_avg:43.72ms -step:111600/500000 train_loss:2.0633 train_time:4879034ms step_avg:43.72ms -step:111800/500000 train_loss:2.0296 train_time:4887776ms step_avg:43.72ms -step:112000/500000 train_loss:2.0067 train_time:4896638ms step_avg:43.72ms -step:112200/500000 train_loss:2.0360 train_time:4905378ms step_avg:43.72ms -step:112400/500000 train_loss:1.8986 train_time:4914115ms step_avg:43.72ms -step:112600/500000 train_loss:1.8944 train_time:4922857ms step_avg:43.72ms -step:112800/500000 train_loss:1.9784 train_time:4931604ms step_avg:43.72ms -step:113000/500000 train_loss:2.0355 train_time:4940342ms step_avg:43.72ms -step:113200/500000 train_loss:1.8432 train_time:4949080ms step_avg:43.72ms -step:113400/500000 train_loss:2.2008 train_time:4957827ms step_avg:43.72ms -step:113600/500000 train_loss:1.9759 train_time:4966563ms step_avg:43.72ms -step:113800/500000 train_loss:1.8877 train_time:4975306ms step_avg:43.72ms -step:114000/500000 train_loss:2.0328 train_time:4984057ms step_avg:43.72ms -step:114200/500000 train_loss:1.9748 train_time:4992790ms step_avg:43.72ms -step:114400/500000 train_loss:1.9246 train_time:5001527ms step_avg:43.72ms -step:114600/500000 train_loss:2.1953 train_time:5010259ms step_avg:43.72ms -step:114800/500000 train_loss:1.9761 train_time:5018992ms step_avg:43.72ms -step:115000/500000 train_loss:1.9908 train_time:5027726ms step_avg:43.72ms -step:115200/500000 train_loss:2.0108 train_time:5036461ms step_avg:43.72ms -step:115400/500000 train_loss:2.0292 train_time:5045199ms step_avg:43.72ms -step:115600/500000 train_loss:2.0888 train_time:5053924ms step_avg:43.72ms -step:115800/500000 train_loss:2.2096 train_time:5062667ms step_avg:43.72ms -step:116000/500000 train_loss:1.9721 train_time:5071397ms step_avg:43.72ms -step:116200/500000 train_loss:2.0199 train_time:5080251ms step_avg:43.72ms -step:116400/500000 train_loss:2.1796 train_time:5088980ms step_avg:43.72ms -step:116600/500000 train_loss:1.9663 train_time:5097712ms step_avg:43.72ms -step:116800/500000 train_loss:2.0318 train_time:5106443ms step_avg:43.72ms -step:117000/500000 train_loss:1.9711 train_time:5115180ms step_avg:43.72ms -step:117200/500000 train_loss:1.7761 train_time:5123922ms step_avg:43.72ms -step:117400/500000 train_loss:2.0077 train_time:5132656ms step_avg:43.72ms -step:117600/500000 train_loss:1.9814 train_time:5141384ms step_avg:43.72ms -step:117800/500000 train_loss:2.1267 train_time:5150118ms step_avg:43.72ms -step:118000/500000 train_loss:1.8956 train_time:5158844ms step_avg:43.72ms -step:118200/500000 train_loss:1.8866 train_time:5167573ms step_avg:43.72ms -step:118400/500000 train_loss:2.0781 train_time:5176309ms step_avg:43.72ms -step:118600/500000 train_loss:2.3821 train_time:5185055ms step_avg:43.72ms -step:118800/500000 train_loss:2.0305 train_time:5193784ms step_avg:43.72ms -step:119000/500000 train_loss:1.9897 train_time:5202517ms step_avg:43.72ms -step:119200/500000 train_loss:2.0434 train_time:5211253ms step_avg:43.72ms -step:119400/500000 train_loss:2.0924 train_time:5220001ms step_avg:43.72ms -step:119600/500000 train_loss:2.0761 train_time:5228736ms step_avg:43.72ms -step:119800/500000 train_loss:2.2817 train_time:5237464ms step_avg:43.72ms -step:120000/500000 train_loss:2.1493 train_time:5246197ms step_avg:43.72ms -step:120000/500000 val_loss:2.0210 val_bpb:1.1969 train_time:5246213ms step_avg:43.72ms -step:120200/500000 train_loss:1.9144 train_time:5255050ms step_avg:43.72ms -step:120400/500000 train_loss:2.0053 train_time:5263788ms step_avg:43.72ms -step:120600/500000 train_loss:1.9489 train_time:5272526ms step_avg:43.72ms -step:120800/500000 train_loss:1.9367 train_time:5281265ms step_avg:43.72ms -step:121000/500000 train_loss:1.9876 train_time:5289998ms step_avg:43.72ms -step:121200/500000 train_loss:2.0388 train_time:5298737ms step_avg:43.72ms -step:121400/500000 train_loss:1.9117 train_time:5307477ms step_avg:43.72ms -step:121600/500000 train_loss:2.1033 train_time:5316209ms step_avg:43.72ms -step:121800/500000 train_loss:1.9801 train_time:5324944ms step_avg:43.72ms -step:122000/500000 train_loss:1.9211 train_time:5333677ms step_avg:43.72ms -step:122200/500000 train_loss:1.9922 train_time:5342421ms step_avg:43.72ms -step:122400/500000 train_loss:2.0145 train_time:5351162ms step_avg:43.72ms -step:122600/500000 train_loss:1.9822 train_time:5359901ms step_avg:43.72ms -step:122800/500000 train_loss:1.9503 train_time:5368632ms step_avg:43.72ms -step:123000/500000 train_loss:2.0314 train_time:5377369ms step_avg:43.72ms -step:123200/500000 train_loss:2.0501 train_time:5386094ms step_avg:43.72ms -step:123400/500000 train_loss:2.0650 train_time:5394818ms step_avg:43.72ms -step:123600/500000 train_loss:2.0346 train_time:5403547ms step_avg:43.72ms -step:123800/500000 train_loss:1.9980 train_time:5412273ms step_avg:43.72ms -step:124000/500000 train_loss:1.8607 train_time:5421006ms step_avg:43.72ms -step:124200/500000 train_loss:2.0196 train_time:5429737ms step_avg:43.72ms -step:124400/500000 train_loss:1.9967 train_time:5438575ms step_avg:43.72ms -step:124600/500000 train_loss:2.0685 train_time:5447313ms step_avg:43.72ms -step:124800/500000 train_loss:2.0213 train_time:5456041ms step_avg:43.72ms -step:125000/500000 train_loss:2.0980 train_time:5464778ms step_avg:43.72ms -step:125200/500000 train_loss:1.9944 train_time:5473524ms step_avg:43.72ms -step:125400/500000 train_loss:2.0300 train_time:5482273ms step_avg:43.72ms -step:125600/500000 train_loss:2.0471 train_time:5491013ms step_avg:43.72ms -step:125800/500000 train_loss:2.0354 train_time:5499761ms step_avg:43.72ms -step:126000/500000 train_loss:2.0454 train_time:5508491ms step_avg:43.72ms -step:126200/500000 train_loss:1.9727 train_time:5517233ms step_avg:43.72ms -step:126400/500000 train_loss:2.1303 train_time:5525963ms step_avg:43.72ms -step:126600/500000 train_loss:2.1323 train_time:5534702ms step_avg:43.72ms -step:126800/500000 train_loss:2.0353 train_time:5543447ms step_avg:43.72ms -step:127000/500000 train_loss:2.0481 train_time:5552191ms step_avg:43.72ms -step:127200/500000 train_loss:1.9966 train_time:5560928ms step_avg:43.72ms -step:127400/500000 train_loss:2.0024 train_time:5569671ms step_avg:43.72ms -step:127600/500000 train_loss:2.1352 train_time:5578409ms step_avg:43.72ms -step:127800/500000 train_loss:1.9181 train_time:5587143ms step_avg:43.72ms -step:128000/500000 train_loss:2.0442 train_time:5595877ms step_avg:43.72ms -step:128200/500000 train_loss:1.9795 train_time:5604623ms step_avg:43.72ms -step:128400/500000 train_loss:2.0676 train_time:5613487ms step_avg:43.72ms -step:128600/500000 train_loss:1.9502 train_time:5622229ms step_avg:43.72ms -step:128800/500000 train_loss:1.9356 train_time:5630963ms step_avg:43.72ms -step:129000/500000 train_loss:1.8350 train_time:5639707ms step_avg:43.72ms -step:129200/500000 train_loss:1.9958 train_time:5648439ms step_avg:43.72ms -step:129400/500000 train_loss:1.7850 train_time:5657189ms step_avg:43.72ms -step:129600/500000 train_loss:2.1434 train_time:5665932ms step_avg:43.72ms -step:129800/500000 train_loss:1.9649 train_time:5674682ms step_avg:43.72ms -step:130000/500000 train_loss:2.0598 train_time:5683416ms step_avg:43.72ms -step:130200/500000 train_loss:2.0008 train_time:5692156ms step_avg:43.72ms -step:130400/500000 train_loss:2.1846 train_time:5700896ms step_avg:43.72ms -step:130600/500000 train_loss:2.1258 train_time:5709644ms step_avg:43.72ms -step:130800/500000 train_loss:2.0358 train_time:5718386ms step_avg:43.72ms -step:131000/500000 train_loss:1.9617 train_time:5727113ms step_avg:43.72ms -step:131200/500000 train_loss:1.9527 train_time:5735854ms step_avg:43.72ms -step:131400/500000 train_loss:1.9986 train_time:5744605ms step_avg:43.72ms -step:131600/500000 train_loss:1.8493 train_time:5753346ms step_avg:43.72ms -step:131800/500000 train_loss:2.0235 train_time:5762090ms step_avg:43.72ms -step:132000/500000 train_loss:1.9611 train_time:5770826ms step_avg:43.72ms -step:132200/500000 train_loss:2.0716 train_time:5779577ms step_avg:43.72ms -step:132400/500000 train_loss:2.0732 train_time:5788321ms step_avg:43.72ms -step:132600/500000 train_loss:1.9754 train_time:5797192ms step_avg:43.72ms -step:132800/500000 train_loss:2.0154 train_time:5805936ms step_avg:43.72ms -step:133000/500000 train_loss:1.9978 train_time:5814679ms step_avg:43.72ms -step:133200/500000 train_loss:2.1977 train_time:5823425ms step_avg:43.72ms -step:133400/500000 train_loss:2.0363 train_time:5832175ms step_avg:43.72ms -step:133600/500000 train_loss:1.8937 train_time:5840923ms step_avg:43.72ms -step:133800/500000 train_loss:1.9849 train_time:5849669ms step_avg:43.72ms -step:134000/500000 train_loss:2.1667 train_time:5858407ms step_avg:43.72ms -step:134200/500000 train_loss:2.1546 train_time:5867169ms step_avg:43.72ms -step:134400/500000 train_loss:2.1969 train_time:5875915ms step_avg:43.72ms -step:134600/500000 train_loss:1.9611 train_time:5884654ms step_avg:43.72ms -step:134800/500000 train_loss:1.9517 train_time:5893390ms step_avg:43.72ms -step:135000/500000 train_loss:2.0010 train_time:5902136ms step_avg:43.72ms -step:135200/500000 train_loss:1.9772 train_time:5910876ms step_avg:43.72ms -step:135400/500000 train_loss:2.1341 train_time:5919617ms step_avg:43.72ms -step:135600/500000 train_loss:2.1205 train_time:5928363ms step_avg:43.72ms -step:135800/500000 train_loss:2.0861 train_time:5937104ms step_avg:43.72ms -step:136000/500000 train_loss:1.9799 train_time:5945849ms step_avg:43.72ms -step:136200/500000 train_loss:1.9827 train_time:5954598ms step_avg:43.72ms -step:136400/500000 train_loss:2.0416 train_time:5963337ms step_avg:43.72ms -step:136600/500000 train_loss:1.9804 train_time:5972081ms step_avg:43.72ms -step:136800/500000 train_loss:2.0562 train_time:5980950ms step_avg:43.72ms -step:137000/500000 train_loss:1.9513 train_time:5989697ms step_avg:43.72ms -step:137200/500000 train_loss:2.0339 train_time:5998445ms step_avg:43.72ms -step:137400/500000 train_loss:2.0287 train_time:6007174ms step_avg:43.72ms -step:137600/500000 train_loss:1.9576 train_time:6015910ms step_avg:43.72ms -step:137800/500000 train_loss:2.0302 train_time:6024660ms step_avg:43.72ms -step:138000/500000 train_loss:2.0371 train_time:6033400ms step_avg:43.72ms -step:138200/500000 train_loss:1.8320 train_time:6042150ms step_avg:43.72ms -step:138400/500000 train_loss:2.0225 train_time:6050887ms step_avg:43.72ms -step:138600/500000 train_loss:1.8439 train_time:6059629ms step_avg:43.72ms -step:138800/500000 train_loss:2.1438 train_time:6068362ms step_avg:43.72ms -step:139000/500000 train_loss:1.9947 train_time:6077102ms step_avg:43.72ms -step:139200/500000 train_loss:2.1038 train_time:6085847ms step_avg:43.72ms -step:139400/500000 train_loss:2.0055 train_time:6094592ms step_avg:43.72ms -step:139600/500000 train_loss:1.9556 train_time:6103326ms step_avg:43.72ms -step:139800/500000 train_loss:1.9420 train_time:6112064ms step_avg:43.72ms -step:140000/500000 train_loss:1.9340 train_time:6120801ms step_avg:43.72ms -step:140000/500000 val_loss:2.0168 val_bpb:1.1945 train_time:6120818ms step_avg:43.72ms -step:140200/500000 train_loss:2.0384 train_time:6129542ms step_avg:43.72ms -step:140400/500000 train_loss:1.9519 train_time:6138277ms step_avg:43.72ms -step:140600/500000 train_loss:1.9898 train_time:6147017ms step_avg:43.72ms -step:140800/500000 train_loss:2.2550 train_time:6155876ms step_avg:43.72ms -step:141000/500000 train_loss:1.9867 train_time:6164617ms step_avg:43.72ms -step:141200/500000 train_loss:2.0205 train_time:6173359ms step_avg:43.72ms -step:141400/500000 train_loss:1.8807 train_time:6182099ms step_avg:43.72ms -step:141600/500000 train_loss:1.9898 train_time:6190842ms step_avg:43.72ms -step:141800/500000 train_loss:1.8427 train_time:6199583ms step_avg:43.72ms -step:142000/500000 train_loss:2.1160 train_time:6208323ms step_avg:43.72ms -step:142200/500000 train_loss:1.8754 train_time:6217067ms step_avg:43.72ms -step:142400/500000 train_loss:2.0096 train_time:6225813ms step_avg:43.72ms -step:142600/500000 train_loss:1.9781 train_time:6234549ms step_avg:43.72ms -step:142800/500000 train_loss:2.0586 train_time:6243289ms step_avg:43.72ms -step:143000/500000 train_loss:2.1796 train_time:6252012ms step_avg:43.72ms -step:143200/500000 train_loss:2.0443 train_time:6260744ms step_avg:43.72ms -step:143400/500000 train_loss:2.0865 train_time:6269501ms step_avg:43.72ms -step:143600/500000 train_loss:2.0975 train_time:6278235ms step_avg:43.72ms -step:143800/500000 train_loss:2.0190 train_time:6286972ms step_avg:43.72ms -step:144000/500000 train_loss:2.0697 train_time:6295708ms step_avg:43.72ms -step:144200/500000 train_loss:1.9661 train_time:6304436ms step_avg:43.72ms -step:144400/500000 train_loss:2.0613 train_time:6313163ms step_avg:43.72ms -step:144600/500000 train_loss:2.2161 train_time:6321898ms step_avg:43.72ms -step:144800/500000 train_loss:2.0155 train_time:6330626ms step_avg:43.72ms -step:145000/500000 train_loss:2.0206 train_time:6339481ms step_avg:43.72ms -step:145200/500000 train_loss:2.0652 train_time:6348218ms step_avg:43.72ms -step:145400/500000 train_loss:2.0265 train_time:6356943ms step_avg:43.72ms -step:145600/500000 train_loss:2.0159 train_time:6365669ms step_avg:43.72ms -step:145800/500000 train_loss:1.8650 train_time:6374392ms step_avg:43.72ms -step:146000/500000 train_loss:2.0364 train_time:6383124ms step_avg:43.72ms -step:146200/500000 train_loss:2.0379 train_time:6391852ms step_avg:43.72ms -step:146400/500000 train_loss:1.9428 train_time:6400597ms step_avg:43.72ms -step:146600/500000 train_loss:1.9223 train_time:6409330ms step_avg:43.72ms -step:146800/500000 train_loss:2.0064 train_time:6418075ms step_avg:43.72ms -step:147000/500000 train_loss:2.0725 train_time:6426812ms step_avg:43.72ms -step:147200/500000 train_loss:1.9570 train_time:6435542ms step_avg:43.72ms -step:147400/500000 train_loss:2.0533 train_time:6444273ms step_avg:43.72ms -step:147600/500000 train_loss:2.0055 train_time:6453012ms step_avg:43.72ms -step:147800/500000 train_loss:2.0093 train_time:6461744ms step_avg:43.72ms -step:148000/500000 train_loss:1.9966 train_time:6470481ms step_avg:43.72ms -step:148200/500000 train_loss:1.9797 train_time:6479223ms step_avg:43.72ms -step:148400/500000 train_loss:2.0464 train_time:6487966ms step_avg:43.72ms -step:148600/500000 train_loss:1.9383 train_time:6496787ms step_avg:43.72ms -step:148800/500000 train_loss:2.1146 train_time:6505531ms step_avg:43.72ms -step:149000/500000 train_loss:2.0930 train_time:6514273ms step_avg:43.72ms -step:149200/500000 train_loss:2.0785 train_time:6523008ms step_avg:43.72ms -step:149400/500000 train_loss:2.0231 train_time:6531742ms step_avg:43.72ms -step:149600/500000 train_loss:2.2172 train_time:6540477ms step_avg:43.72ms -step:149800/500000 train_loss:2.0192 train_time:6549214ms step_avg:43.72ms -step:150000/500000 train_loss:2.0543 train_time:6557949ms step_avg:43.72ms -step:150200/500000 train_loss:1.9361 train_time:6566691ms step_avg:43.72ms -step:150400/500000 train_loss:2.0412 train_time:6575420ms step_avg:43.72ms -step:150600/500000 train_loss:1.9852 train_time:6584160ms step_avg:43.72ms -step:150800/500000 train_loss:2.0049 train_time:6592900ms step_avg:43.72ms -step:151000/500000 train_loss:2.0040 train_time:6601638ms step_avg:43.72ms -step:151200/500000 train_loss:2.1101 train_time:6610372ms step_avg:43.72ms -step:151400/500000 train_loss:2.0910 train_time:6619105ms step_avg:43.72ms -step:151600/500000 train_loss:2.0578 train_time:6627844ms step_avg:43.72ms -step:151800/500000 train_loss:2.0207 train_time:6636581ms step_avg:43.72ms -step:152000/500000 train_loss:2.0865 train_time:6645440ms step_avg:43.72ms -step:152200/500000 train_loss:2.0833 train_time:6654166ms step_avg:43.72ms -step:152400/500000 train_loss:2.1092 train_time:6662895ms step_avg:43.72ms -step:152600/500000 train_loss:1.9655 train_time:6671624ms step_avg:43.72ms -step:152800/500000 train_loss:1.9758 train_time:6680366ms step_avg:43.72ms -step:153000/500000 train_loss:2.2251 train_time:6689098ms step_avg:43.72ms -step:153200/500000 train_loss:2.0474 train_time:6697837ms step_avg:43.72ms -step:153400/500000 train_loss:1.9763 train_time:6706572ms step_avg:43.72ms -step:153600/500000 train_loss:1.8641 train_time:6715312ms step_avg:43.72ms -step:153800/500000 train_loss:2.0611 train_time:6724045ms step_avg:43.72ms -step:154000/500000 train_loss:2.0526 train_time:6732775ms step_avg:43.72ms -step:154200/500000 train_loss:2.1270 train_time:6741515ms step_avg:43.72ms -step:154400/500000 train_loss:1.9840 train_time:6750256ms step_avg:43.72ms -step:154600/500000 train_loss:2.0056 train_time:6758997ms step_avg:43.72ms -step:154800/500000 train_loss:2.0772 train_time:6767752ms step_avg:43.72ms -step:155000/500000 train_loss:2.1605 train_time:6776484ms step_avg:43.72ms -step:155200/500000 train_loss:1.9979 train_time:6785227ms step_avg:43.72ms -step:155400/500000 train_loss:2.0597 train_time:6793955ms step_avg:43.72ms -step:155600/500000 train_loss:1.9873 train_time:6802697ms step_avg:43.72ms -step:155800/500000 train_loss:1.9067 train_time:6811437ms step_avg:43.72ms -step:156000/500000 train_loss:2.0082 train_time:6820174ms step_avg:43.72ms -step:156200/500000 train_loss:1.8543 train_time:6829025ms step_avg:43.72ms -step:156400/500000 train_loss:2.0477 train_time:6837757ms step_avg:43.72ms -step:156600/500000 train_loss:2.0584 train_time:6846488ms step_avg:43.72ms -step:156800/500000 train_loss:2.0802 train_time:6855230ms step_avg:43.72ms -step:157000/500000 train_loss:2.0046 train_time:6863969ms step_avg:43.72ms -step:157200/500000 train_loss:2.0375 train_time:6872709ms step_avg:43.72ms -step:157400/500000 train_loss:2.1092 train_time:6881444ms step_avg:43.72ms -step:157600/500000 train_loss:1.9823 train_time:6890182ms step_avg:43.72ms -step:157800/500000 train_loss:2.0278 train_time:6898919ms step_avg:43.72ms -step:158000/500000 train_loss:2.0222 train_time:6907648ms step_avg:43.72ms -step:158200/500000 train_loss:1.9798 train_time:6916376ms step_avg:43.72ms -step:158400/500000 train_loss:1.9678 train_time:6925111ms step_avg:43.72ms -step:158600/500000 train_loss:2.0966 train_time:6933843ms step_avg:43.72ms -step:158800/500000 train_loss:1.9167 train_time:6942577ms step_avg:43.72ms -step:159000/500000 train_loss:1.9557 train_time:6951310ms step_avg:43.72ms -step:159200/500000 train_loss:1.9954 train_time:6960046ms step_avg:43.72ms -step:159400/500000 train_loss:1.9841 train_time:6968776ms step_avg:43.72ms -step:159600/500000 train_loss:2.0950 train_time:6977509ms step_avg:43.72ms -step:159800/500000 train_loss:2.0244 train_time:6986242ms step_avg:43.72ms -step:160000/500000 train_loss:1.9207 train_time:6994982ms step_avg:43.72ms -step:160000/500000 val_loss:2.0151 val_bpb:1.1935 train_time:6994998ms step_avg:43.72ms -step:160200/500000 train_loss:2.0018 train_time:7003716ms step_avg:43.72ms -step:160400/500000 train_loss:2.0475 train_time:7012579ms step_avg:43.72ms -step:160600/500000 train_loss:1.9469 train_time:7021324ms step_avg:43.72ms -step:160800/500000 train_loss:1.9277 train_time:7030059ms step_avg:43.72ms -step:161000/500000 train_loss:1.9321 train_time:7038802ms step_avg:43.72ms -step:161200/500000 train_loss:1.9661 train_time:7047535ms step_avg:43.72ms -step:161400/500000 train_loss:2.0205 train_time:7056270ms step_avg:43.72ms -step:161600/500000 train_loss:1.9691 train_time:7065010ms step_avg:43.72ms -step:161800/500000 train_loss:1.9580 train_time:7073749ms step_avg:43.72ms -step:162000/500000 train_loss:2.1239 train_time:7082482ms step_avg:43.72ms -step:162200/500000 train_loss:2.1069 train_time:7091216ms step_avg:43.72ms -step:162400/500000 train_loss:1.9907 train_time:7099955ms step_avg:43.72ms -step:162600/500000 train_loss:2.0223 train_time:7108693ms step_avg:43.72ms -step:162800/500000 train_loss:1.9429 train_time:7117432ms step_avg:43.72ms -step:163000/500000 train_loss:2.0619 train_time:7126171ms step_avg:43.72ms -step:163200/500000 train_loss:1.9096 train_time:7134910ms step_avg:43.72ms -step:163400/500000 train_loss:1.8932 train_time:7143647ms step_avg:43.72ms -step:163600/500000 train_loss:2.0096 train_time:7152383ms step_avg:43.72ms -step:163800/500000 train_loss:1.9822 train_time:7161125ms step_avg:43.72ms -step:164000/500000 train_loss:1.8641 train_time:7169864ms step_avg:43.72ms -step:164200/500000 train_loss:2.0342 train_time:7178599ms step_avg:43.72ms -step:164400/500000 train_loss:1.9387 train_time:7187462ms step_avg:43.72ms -step:164600/500000 train_loss:2.0189 train_time:7196201ms step_avg:43.72ms -step:164800/500000 train_loss:1.9743 train_time:7204932ms step_avg:43.72ms -step:165000/500000 train_loss:2.0910 train_time:7213663ms step_avg:43.72ms -step:165200/500000 train_loss:2.0131 train_time:7222396ms step_avg:43.72ms -step:165400/500000 train_loss:1.8053 train_time:7231130ms step_avg:43.72ms -step:165600/500000 train_loss:2.1171 train_time:7239864ms step_avg:43.72ms -step:165800/500000 train_loss:2.0570 train_time:7248591ms step_avg:43.72ms -step:166000/500000 train_loss:2.1153 train_time:7257330ms step_avg:43.72ms -step:166200/500000 train_loss:2.0100 train_time:7266063ms step_avg:43.72ms -step:166400/500000 train_loss:1.9842 train_time:7274794ms step_avg:43.72ms -step:166600/500000 train_loss:2.0044 train_time:7283530ms step_avg:43.72ms -step:166800/500000 train_loss:1.9264 train_time:7292273ms step_avg:43.72ms -step:167000/500000 train_loss:2.0284 train_time:7301000ms step_avg:43.72ms -step:167200/500000 train_loss:2.0714 train_time:7309735ms step_avg:43.72ms -step:167400/500000 train_loss:1.8683 train_time:7318470ms step_avg:43.72ms -step:167600/500000 train_loss:2.0481 train_time:7327205ms step_avg:43.72ms -step:167800/500000 train_loss:2.0047 train_time:7335945ms step_avg:43.72ms -step:168000/500000 train_loss:2.0068 train_time:7344682ms step_avg:43.72ms -step:168200/500000 train_loss:1.9203 train_time:7353418ms step_avg:43.72ms -step:168400/500000 train_loss:1.9711 train_time:7362156ms step_avg:43.72ms -step:168600/500000 train_loss:2.0499 train_time:7371013ms step_avg:43.72ms -step:168800/500000 train_loss:2.0193 train_time:7379753ms step_avg:43.72ms -step:169000/500000 train_loss:1.9500 train_time:7388495ms step_avg:43.72ms -step:169200/500000 train_loss:1.9424 train_time:7397235ms step_avg:43.72ms -step:169400/500000 train_loss:2.0937 train_time:7405974ms step_avg:43.72ms -step:169600/500000 train_loss:2.0015 train_time:7414717ms step_avg:43.72ms -step:169800/500000 train_loss:1.9852 train_time:7423454ms step_avg:43.72ms -step:170000/500000 train_loss:1.8737 train_time:7432193ms step_avg:43.72ms -step:170200/500000 train_loss:2.0264 train_time:7440931ms step_avg:43.72ms -step:170400/500000 train_loss:2.1094 train_time:7449668ms step_avg:43.72ms -step:170600/500000 train_loss:1.9665 train_time:7458399ms step_avg:43.72ms -step:170800/500000 train_loss:1.9582 train_time:7467133ms step_avg:43.72ms -step:171000/500000 train_loss:1.9925 train_time:7475896ms step_avg:43.72ms -step:171200/500000 train_loss:2.1872 train_time:7484633ms step_avg:43.72ms -step:171400/500000 train_loss:2.0731 train_time:7493380ms step_avg:43.72ms -step:171600/500000 train_loss:1.9710 train_time:7502128ms step_avg:43.72ms -step:171800/500000 train_loss:1.9976 train_time:7510868ms step_avg:43.72ms -step:172000/500000 train_loss:1.9109 train_time:7519611ms step_avg:43.72ms -step:172200/500000 train_loss:1.9541 train_time:7528366ms step_avg:43.72ms -step:172400/500000 train_loss:2.0850 train_time:7537121ms step_avg:43.72ms -step:172600/500000 train_loss:1.8652 train_time:7545995ms step_avg:43.72ms -step:172800/500000 train_loss:1.9614 train_time:7554747ms step_avg:43.72ms -step:173000/500000 train_loss:2.1891 train_time:7563490ms step_avg:43.72ms -step:173200/500000 train_loss:1.9574 train_time:7572235ms step_avg:43.72ms -step:173400/500000 train_loss:2.0368 train_time:7580959ms step_avg:43.72ms -step:173600/500000 train_loss:2.0617 train_time:7589684ms step_avg:43.72ms -step:173800/500000 train_loss:2.1104 train_time:7598407ms step_avg:43.72ms -step:174000/500000 train_loss:2.0506 train_time:7607136ms step_avg:43.72ms -step:174200/500000 train_loss:1.9299 train_time:7615871ms step_avg:43.72ms -step:174400/500000 train_loss:2.1104 train_time:7624607ms step_avg:43.72ms -step:174600/500000 train_loss:2.1805 train_time:7633333ms step_avg:43.72ms -step:174800/500000 train_loss:2.0631 train_time:7642073ms step_avg:43.72ms -step:175000/500000 train_loss:2.1022 train_time:7650819ms step_avg:43.72ms -step:175200/500000 train_loss:1.8115 train_time:7659557ms step_avg:43.72ms -step:175400/500000 train_loss:2.0692 train_time:7668291ms step_avg:43.72ms -step:175600/500000 train_loss:2.1445 train_time:7677029ms step_avg:43.72ms -step:175800/500000 train_loss:1.9551 train_time:7685758ms step_avg:43.72ms -step:176000/500000 train_loss:1.9635 train_time:7694493ms step_avg:43.72ms -step:176200/500000 train_loss:1.9418 train_time:7703234ms step_avg:43.72ms -step:176400/500000 train_loss:2.1112 train_time:7711975ms step_avg:43.72ms -step:176600/500000 train_loss:2.1513 train_time:7720710ms step_avg:43.72ms -step:176800/500000 train_loss:2.0037 train_time:7729571ms step_avg:43.72ms -step:177000/500000 train_loss:1.9129 train_time:7738307ms step_avg:43.72ms -step:177200/500000 train_loss:2.1441 train_time:7747046ms step_avg:43.72ms -step:177400/500000 train_loss:1.9549 train_time:7755780ms step_avg:43.72ms -step:177600/500000 train_loss:1.9830 train_time:7764513ms step_avg:43.72ms -step:177800/500000 train_loss:1.9886 train_time:7773254ms step_avg:43.72ms -step:178000/500000 train_loss:2.1259 train_time:7781991ms step_avg:43.72ms -step:178200/500000 train_loss:1.9354 train_time:7790727ms step_avg:43.72ms -step:178400/500000 train_loss:2.1140 train_time:7799471ms step_avg:43.72ms -step:178600/500000 train_loss:1.8590 train_time:7808207ms step_avg:43.72ms -step:178800/500000 train_loss:2.0454 train_time:7816945ms step_avg:43.72ms -step:179000/500000 train_loss:1.9407 train_time:7825676ms step_avg:43.72ms -step:179200/500000 train_loss:1.9950 train_time:7834409ms step_avg:43.72ms -step:179400/500000 train_loss:1.9914 train_time:7843141ms step_avg:43.72ms -step:179600/500000 train_loss:2.0901 train_time:7851877ms step_avg:43.72ms -step:179800/500000 train_loss:1.9138 train_time:7860608ms step_avg:43.72ms -step:180000/500000 train_loss:2.0217 train_time:7869340ms step_avg:43.72ms -step:180000/500000 val_loss:2.0117 val_bpb:1.1914 train_time:7869356ms step_avg:43.72ms -step:180200/500000 train_loss:2.0164 train_time:7878077ms step_avg:43.72ms -step:180400/500000 train_loss:2.0996 train_time:7886806ms step_avg:43.72ms -step:180600/500000 train_loss:2.1237 train_time:7895545ms step_avg:43.72ms -step:180800/500000 train_loss:2.0568 train_time:7904280ms step_avg:43.72ms -step:181000/500000 train_loss:2.0839 train_time:7913142ms step_avg:43.72ms -step:181200/500000 train_loss:1.9792 train_time:7921880ms step_avg:43.72ms -step:181400/500000 train_loss:1.8169 train_time:7930628ms step_avg:43.72ms -step:181600/500000 train_loss:2.0093 train_time:7939374ms step_avg:43.72ms -step:181800/500000 train_loss:1.8723 train_time:7948111ms step_avg:43.72ms -step:182000/500000 train_loss:2.0589 train_time:7956844ms step_avg:43.72ms -step:182200/500000 train_loss:2.2871 train_time:7965596ms step_avg:43.72ms -step:182400/500000 train_loss:1.9727 train_time:7974326ms step_avg:43.72ms -step:182600/500000 train_loss:2.1237 train_time:7983068ms step_avg:43.72ms -step:182800/500000 train_loss:1.9695 train_time:7991805ms step_avg:43.72ms -step:183000/500000 train_loss:1.9939 train_time:8000541ms step_avg:43.72ms -step:183200/500000 train_loss:2.1586 train_time:8009273ms step_avg:43.72ms -step:183400/500000 train_loss:1.9795 train_time:8018006ms step_avg:43.72ms -step:183600/500000 train_loss:2.0676 train_time:8026744ms step_avg:43.72ms -step:183800/500000 train_loss:1.9571 train_time:8035482ms step_avg:43.72ms -step:184000/500000 train_loss:2.0105 train_time:8044212ms step_avg:43.72ms -step:184200/500000 train_loss:2.0779 train_time:8052949ms step_avg:43.72ms -step:184400/500000 train_loss:2.0069 train_time:8061687ms step_avg:43.72ms -step:184600/500000 train_loss:2.0897 train_time:8070420ms step_avg:43.72ms -step:184800/500000 train_loss:1.9466 train_time:8079148ms step_avg:43.72ms -step:185000/500000 train_loss:2.0431 train_time:8088000ms step_avg:43.72ms -step:185200/500000 train_loss:2.0700 train_time:8096735ms step_avg:43.72ms -step:185400/500000 train_loss:1.8829 train_time:8105470ms step_avg:43.72ms -step:185600/500000 train_loss:2.0099 train_time:8114181ms step_avg:43.72ms -step:185800/500000 train_loss:1.9734 train_time:8122913ms step_avg:43.72ms -step:186000/500000 train_loss:2.0107 train_time:8131657ms step_avg:43.72ms -step:186200/500000 train_loss:1.8863 train_time:8140408ms step_avg:43.72ms -step:186400/500000 train_loss:1.8992 train_time:8149145ms step_avg:43.72ms -step:186600/500000 train_loss:2.0194 train_time:8157881ms step_avg:43.72ms -step:186800/500000 train_loss:2.0629 train_time:8166619ms step_avg:43.72ms -step:187000/500000 train_loss:2.0338 train_time:8175361ms step_avg:43.72ms -step:187200/500000 train_loss:2.0185 train_time:8184101ms step_avg:43.72ms -step:187400/500000 train_loss:2.1287 train_time:8192842ms step_avg:43.72ms -step:187600/500000 train_loss:1.9641 train_time:8201582ms step_avg:43.72ms -step:187800/500000 train_loss:1.8271 train_time:8210327ms step_avg:43.72ms -step:188000/500000 train_loss:1.8761 train_time:8219195ms step_avg:43.72ms -step:188200/500000 train_loss:2.0407 train_time:8227939ms step_avg:43.72ms -step:188400/500000 train_loss:1.9929 train_time:8236687ms step_avg:43.72ms -step:188600/500000 train_loss:2.1855 train_time:8245431ms step_avg:43.72ms -step:188800/500000 train_loss:1.9462 train_time:8254187ms step_avg:43.72ms -step:189000/500000 train_loss:2.0223 train_time:8262930ms step_avg:43.72ms -step:189200/500000 train_loss:2.0223 train_time:8271678ms step_avg:43.72ms -step:189400/500000 train_loss:1.9818 train_time:8280415ms step_avg:43.72ms -step:189600/500000 train_loss:1.9641 train_time:8289171ms step_avg:43.72ms -step:189800/500000 train_loss:2.0795 train_time:8297919ms step_avg:43.72ms -step:190000/500000 train_loss:1.9859 train_time:8306658ms step_avg:43.72ms -step:190200/500000 train_loss:2.0013 train_time:8315407ms step_avg:43.72ms -step:190400/500000 train_loss:1.8687 train_time:8324156ms step_avg:43.72ms -step:190600/500000 train_loss:2.0704 train_time:8332901ms step_avg:43.72ms -step:190800/500000 train_loss:2.0072 train_time:8341640ms step_avg:43.72ms -step:191000/500000 train_loss:1.9459 train_time:8350385ms step_avg:43.72ms -step:191200/500000 train_loss:2.0498 train_time:8359126ms step_avg:43.72ms -step:191400/500000 train_loss:2.0077 train_time:8367862ms step_avg:43.72ms -step:191600/500000 train_loss:2.0151 train_time:8376604ms step_avg:43.72ms -step:191800/500000 train_loss:2.0355 train_time:8385354ms step_avg:43.72ms -step:192000/500000 train_loss:1.9724 train_time:8394096ms step_avg:43.72ms -step:192200/500000 train_loss:2.0984 train_time:8402951ms step_avg:43.72ms -step:192400/500000 train_loss:1.8789 train_time:8411689ms step_avg:43.72ms -step:192600/500000 train_loss:1.8561 train_time:8420432ms step_avg:43.72ms -step:192800/500000 train_loss:2.0320 train_time:8429163ms step_avg:43.72ms -step:193000/500000 train_loss:2.1582 train_time:8437899ms step_avg:43.72ms -step:193200/500000 train_loss:1.8950 train_time:8446638ms step_avg:43.72ms -step:193400/500000 train_loss:1.9427 train_time:8455369ms step_avg:43.72ms -step:193600/500000 train_loss:2.0281 train_time:8464117ms step_avg:43.72ms -step:193800/500000 train_loss:2.1803 train_time:8472865ms step_avg:43.72ms -step:194000/500000 train_loss:2.0341 train_time:8481600ms step_avg:43.72ms -step:194200/500000 train_loss:2.0818 train_time:8490324ms step_avg:43.72ms -step:194400/500000 train_loss:2.0853 train_time:8499050ms step_avg:43.72ms -step:194600/500000 train_loss:2.1322 train_time:8507769ms step_avg:43.72ms -step:194800/500000 train_loss:1.9417 train_time:8516499ms step_avg:43.72ms -step:195000/500000 train_loss:1.9583 train_time:8525228ms step_avg:43.72ms -step:195200/500000 train_loss:2.1288 train_time:8533957ms step_avg:43.72ms -step:195400/500000 train_loss:1.8735 train_time:8542684ms step_avg:43.72ms -step:195600/500000 train_loss:1.9380 train_time:8551417ms step_avg:43.72ms -step:195800/500000 train_loss:2.0153 train_time:8560153ms step_avg:43.72ms -step:196000/500000 train_loss:2.0793 train_time:8568904ms step_avg:43.72ms -step:196200/500000 train_loss:1.9494 train_time:8577766ms step_avg:43.72ms -step:196400/500000 train_loss:2.0027 train_time:8586500ms step_avg:43.72ms -step:196600/500000 train_loss:1.8933 train_time:8595250ms step_avg:43.72ms -step:196800/500000 train_loss:2.1100 train_time:8603995ms step_avg:43.72ms -step:197000/500000 train_loss:1.9367 train_time:8612745ms step_avg:43.72ms -step:197200/500000 train_loss:2.0373 train_time:8621490ms step_avg:43.72ms -step:197400/500000 train_loss:1.9918 train_time:8630234ms step_avg:43.72ms -step:197600/500000 train_loss:1.9524 train_time:8638976ms step_avg:43.72ms -step:197800/500000 train_loss:2.0744 train_time:8647712ms step_avg:43.72ms -step:198000/500000 train_loss:2.0879 train_time:8656445ms step_avg:43.72ms -step:198200/500000 train_loss:1.8853 train_time:8665173ms step_avg:43.72ms -step:198400/500000 train_loss:1.9999 train_time:8673907ms step_avg:43.72ms -step:198600/500000 train_loss:2.0703 train_time:8682644ms step_avg:43.72ms -step:198800/500000 train_loss:1.9407 train_time:8691376ms step_avg:43.72ms -step:199000/500000 train_loss:1.9825 train_time:8700102ms step_avg:43.72ms -step:199200/500000 train_loss:2.2138 train_time:8708839ms step_avg:43.72ms -step:199400/500000 train_loss:2.0556 train_time:8717569ms step_avg:43.72ms -step:199600/500000 train_loss:1.8311 train_time:8726291ms step_avg:43.72ms -step:199800/500000 train_loss:1.8928 train_time:8735019ms step_avg:43.72ms -step:200000/500000 train_loss:1.9818 train_time:8743744ms step_avg:43.72ms -step:200000/500000 val_loss:2.0105 val_bpb:1.1907 train_time:8743760ms step_avg:43.72ms -step:200200/500000 train_loss:1.9868 train_time:8752474ms step_avg:43.72ms -step:200400/500000 train_loss:2.1167 train_time:8761320ms step_avg:43.72ms -step:200600/500000 train_loss:1.9449 train_time:8770043ms step_avg:43.72ms -step:200800/500000 train_loss:2.0322 train_time:8778769ms step_avg:43.72ms -step:201000/500000 train_loss:2.0473 train_time:8787500ms step_avg:43.72ms -step:201200/500000 train_loss:1.9667 train_time:8796227ms step_avg:43.72ms -step:201400/500000 train_loss:2.1328 train_time:8804954ms step_avg:43.72ms -step:201600/500000 train_loss:2.0573 train_time:8813695ms step_avg:43.72ms -step:201800/500000 train_loss:1.9931 train_time:8822429ms step_avg:43.72ms -step:202000/500000 train_loss:2.1286 train_time:8831155ms step_avg:43.72ms -step:202200/500000 train_loss:2.0619 train_time:8839888ms step_avg:43.72ms -step:202400/500000 train_loss:2.1035 train_time:8848616ms step_avg:43.72ms -step:202600/500000 train_loss:2.0375 train_time:8857346ms step_avg:43.72ms -step:202800/500000 train_loss:1.9938 train_time:8866075ms step_avg:43.72ms -step:203000/500000 train_loss:1.8688 train_time:8874808ms step_avg:43.72ms -step:203200/500000 train_loss:1.9930 train_time:8883537ms step_avg:43.72ms -step:203400/500000 train_loss:1.9899 train_time:8892264ms step_avg:43.72ms -step:203600/500000 train_loss:1.7814 train_time:8901000ms step_avg:43.72ms -step:203800/500000 train_loss:1.9964 train_time:8909742ms step_avg:43.72ms -step:204000/500000 train_loss:2.0461 train_time:8918478ms step_avg:43.72ms -step:204200/500000 train_loss:1.9670 train_time:8927209ms step_avg:43.72ms -step:204400/500000 train_loss:2.0828 train_time:8936045ms step_avg:43.72ms -step:204600/500000 train_loss:2.0625 train_time:8944774ms step_avg:43.72ms -step:204800/500000 train_loss:1.9721 train_time:8953505ms step_avg:43.72ms -step:205000/500000 train_loss:1.8662 train_time:8962232ms step_avg:43.72ms -step:205200/500000 train_loss:1.9807 train_time:8970959ms step_avg:43.72ms -step:205400/500000 train_loss:2.0544 train_time:8979687ms step_avg:43.72ms -step:205600/500000 train_loss:1.9625 train_time:8988423ms step_avg:43.72ms -step:205800/500000 train_loss:2.0482 train_time:8997154ms step_avg:43.72ms -step:206000/500000 train_loss:1.9984 train_time:9005880ms step_avg:43.72ms -step:206200/500000 train_loss:1.9716 train_time:9014613ms step_avg:43.72ms -step:206400/500000 train_loss:2.0728 train_time:9023343ms step_avg:43.72ms -step:206600/500000 train_loss:1.9063 train_time:9032073ms step_avg:43.72ms -step:206800/500000 train_loss:1.9478 train_time:9040806ms step_avg:43.72ms -step:207000/500000 train_loss:2.0179 train_time:9049527ms step_avg:43.72ms -step:207200/500000 train_loss:2.0109 train_time:9058254ms step_avg:43.72ms -step:207400/500000 train_loss:2.0661 train_time:9066986ms step_avg:43.72ms -step:207600/500000 train_loss:1.8557 train_time:9075723ms step_avg:43.72ms -step:207800/500000 train_loss:1.8529 train_time:9084456ms step_avg:43.72ms -step:208000/500000 train_loss:2.0855 train_time:9093188ms step_avg:43.72ms -step:208200/500000 train_loss:2.0335 train_time:9101921ms step_avg:43.72ms -step:208400/500000 train_loss:1.8988 train_time:9110652ms step_avg:43.72ms -step:208600/500000 train_loss:1.9468 train_time:9119502ms step_avg:43.72ms -step:208800/500000 train_loss:1.9999 train_time:9128232ms step_avg:43.72ms -step:209000/500000 train_loss:2.0396 train_time:9136961ms step_avg:43.72ms -step:209200/500000 train_loss:2.0371 train_time:9145693ms step_avg:43.72ms -step:209400/500000 train_loss:2.1630 train_time:9154422ms step_avg:43.72ms -step:209600/500000 train_loss:1.9490 train_time:9163147ms step_avg:43.72ms -step:209800/500000 train_loss:1.9805 train_time:9171877ms step_avg:43.72ms -step:210000/500000 train_loss:2.0601 train_time:9180605ms step_avg:43.72ms -step:210200/500000 train_loss:2.0570 train_time:9189333ms step_avg:43.72ms -step:210400/500000 train_loss:2.0318 train_time:9198063ms step_avg:43.72ms -step:210600/500000 train_loss:2.1689 train_time:9206795ms step_avg:43.72ms -step:210800/500000 train_loss:1.9921 train_time:9215524ms step_avg:43.72ms -step:211000/500000 train_loss:2.0106 train_time:9224247ms step_avg:43.72ms -step:211200/500000 train_loss:2.0720 train_time:9232976ms step_avg:43.72ms -step:211400/500000 train_loss:2.1134 train_time:9241710ms step_avg:43.72ms -step:211600/500000 train_loss:2.1559 train_time:9250441ms step_avg:43.72ms -step:211800/500000 train_loss:2.0516 train_time:9259174ms step_avg:43.72ms -step:212000/500000 train_loss:2.0935 train_time:9267899ms step_avg:43.72ms -step:212200/500000 train_loss:2.0999 train_time:9276632ms step_avg:43.72ms -step:212400/500000 train_loss:2.0021 train_time:9285357ms step_avg:43.72ms -step:212600/500000 train_loss:2.0492 train_time:9294090ms step_avg:43.72ms -step:212800/500000 train_loss:2.0105 train_time:9302943ms step_avg:43.72ms -step:213000/500000 train_loss:1.9174 train_time:9311671ms step_avg:43.72ms -step:213200/500000 train_loss:2.1274 train_time:9320401ms step_avg:43.72ms -step:213400/500000 train_loss:1.9693 train_time:9329132ms step_avg:43.72ms -step:213600/500000 train_loss:1.8956 train_time:9337862ms step_avg:43.72ms -step:213800/500000 train_loss:2.0412 train_time:9346591ms step_avg:43.72ms -step:214000/500000 train_loss:2.0581 train_time:9355323ms step_avg:43.72ms -step:214200/500000 train_loss:2.0161 train_time:9364052ms step_avg:43.72ms -step:214400/500000 train_loss:2.0053 train_time:9372785ms step_avg:43.72ms -step:214600/500000 train_loss:2.0402 train_time:9381520ms step_avg:43.72ms -step:214800/500000 train_loss:1.9926 train_time:9390259ms step_avg:43.72ms -step:215000/500000 train_loss:2.0061 train_time:9398986ms step_avg:43.72ms -step:215200/500000 train_loss:1.9774 train_time:9407719ms step_avg:43.72ms -step:215400/500000 train_loss:2.0146 train_time:9416449ms step_avg:43.72ms -step:215600/500000 train_loss:2.0522 train_time:9425178ms step_avg:43.72ms -step:215800/500000 train_loss:1.9747 train_time:9433914ms step_avg:43.72ms -step:216000/500000 train_loss:2.0128 train_time:9442642ms step_avg:43.72ms -step:216200/500000 train_loss:2.0113 train_time:9451366ms step_avg:43.72ms -step:216400/500000 train_loss:1.8583 train_time:9460100ms step_avg:43.72ms -step:216600/500000 train_loss:2.0709 train_time:9468838ms step_avg:43.72ms -step:216800/500000 train_loss:2.0007 train_time:9477686ms step_avg:43.72ms -step:217000/500000 train_loss:2.0354 train_time:9486424ms step_avg:43.72ms -step:217200/500000 train_loss:2.0861 train_time:9495155ms step_avg:43.72ms -step:217400/500000 train_loss:1.9562 train_time:9503889ms step_avg:43.72ms -step:217600/500000 train_loss:2.0061 train_time:9512625ms step_avg:43.72ms -step:217800/500000 train_loss:1.9055 train_time:9521360ms step_avg:43.72ms -step:218000/500000 train_loss:2.0533 train_time:9530097ms step_avg:43.72ms -step:218200/500000 train_loss:2.0140 train_time:9538834ms step_avg:43.72ms -step:218400/500000 train_loss:2.0301 train_time:9547568ms step_avg:43.72ms -step:218600/500000 train_loss:2.1542 train_time:9556315ms step_avg:43.72ms -step:218800/500000 train_loss:2.0486 train_time:9565045ms step_avg:43.72ms -step:219000/500000 train_loss:1.9843 train_time:9573785ms step_avg:43.72ms -step:219200/500000 train_loss:1.9709 train_time:9582517ms step_avg:43.72ms -step:219400/500000 train_loss:2.0735 train_time:9591258ms step_avg:43.72ms -step:219600/500000 train_loss:1.9672 train_time:9600003ms step_avg:43.72ms -step:219800/500000 train_loss:1.9823 train_time:9608744ms step_avg:43.72ms -step:220000/500000 train_loss:1.9377 train_time:9617487ms step_avg:43.72ms -step:220000/500000 val_loss:2.0076 val_bpb:1.1890 train_time:9617503ms step_avg:43.72ms -step:220200/500000 train_loss:2.0217 train_time:9626225ms step_avg:43.72ms -step:220400/500000 train_loss:2.0827 train_time:9634959ms step_avg:43.72ms -step:220600/500000 train_loss:2.0294 train_time:9643691ms step_avg:43.72ms -step:220800/500000 train_loss:2.1254 train_time:9652438ms step_avg:43.72ms -step:221000/500000 train_loss:2.1577 train_time:9661294ms step_avg:43.72ms -step:221200/500000 train_loss:2.0463 train_time:9670034ms step_avg:43.72ms -step:221400/500000 train_loss:2.1585 train_time:9678765ms step_avg:43.72ms -step:221600/500000 train_loss:1.9449 train_time:9687499ms step_avg:43.72ms -step:221800/500000 train_loss:2.1088 train_time:9696229ms step_avg:43.72ms -step:222000/500000 train_loss:1.9759 train_time:9704959ms step_avg:43.72ms -step:222200/500000 train_loss:1.9192 train_time:9713682ms step_avg:43.72ms -step:222400/500000 train_loss:1.9609 train_time:9722414ms step_avg:43.72ms -step:222600/500000 train_loss:1.9626 train_time:9731139ms step_avg:43.72ms -step:222800/500000 train_loss:1.9320 train_time:9739836ms step_avg:43.72ms -step:223000/500000 train_loss:2.0679 train_time:9748562ms step_avg:43.72ms -step:223200/500000 train_loss:1.9997 train_time:9757297ms step_avg:43.72ms -step:223400/500000 train_loss:1.8773 train_time:9766028ms step_avg:43.72ms -step:223600/500000 train_loss:2.0217 train_time:9774760ms step_avg:43.72ms -step:223800/500000 train_loss:2.1673 train_time:9783485ms step_avg:43.72ms -step:224000/500000 train_loss:1.9545 train_time:9792341ms step_avg:43.72ms -step:224200/500000 train_loss:1.9287 train_time:9801075ms step_avg:43.72ms -step:224400/500000 train_loss:1.9226 train_time:9809802ms step_avg:43.72ms -step:224600/500000 train_loss:2.0093 train_time:9818539ms step_avg:43.72ms -step:224800/500000 train_loss:2.0917 train_time:9827270ms step_avg:43.72ms -step:225000/500000 train_loss:1.8178 train_time:9836011ms step_avg:43.72ms -step:225200/500000 train_loss:1.9130 train_time:9844760ms step_avg:43.72ms -step:225400/500000 train_loss:2.0751 train_time:9853506ms step_avg:43.72ms -step:225600/500000 train_loss:1.9437 train_time:9862235ms step_avg:43.72ms -step:225800/500000 train_loss:2.0723 train_time:9870970ms step_avg:43.72ms -step:226000/500000 train_loss:1.9647 train_time:9879707ms step_avg:43.72ms -step:226200/500000 train_loss:2.0678 train_time:9888446ms step_avg:43.72ms -step:226400/500000 train_loss:2.0256 train_time:9897181ms step_avg:43.72ms -step:226600/500000 train_loss:2.0307 train_time:9905921ms step_avg:43.72ms -step:226800/500000 train_loss:2.1227 train_time:9914660ms step_avg:43.72ms -step:227000/500000 train_loss:2.0353 train_time:9923391ms step_avg:43.72ms -step:227200/500000 train_loss:1.9588 train_time:9932120ms step_avg:43.72ms -step:227400/500000 train_loss:2.0887 train_time:9940856ms step_avg:43.72ms -step:227600/500000 train_loss:1.9178 train_time:9949585ms step_avg:43.72ms -step:227800/500000 train_loss:2.0666 train_time:9958313ms step_avg:43.72ms -step:228000/500000 train_loss:1.9225 train_time:9967159ms step_avg:43.72ms -step:228200/500000 train_loss:1.9397 train_time:9975883ms step_avg:43.72ms -step:228400/500000 train_loss:1.9910 train_time:9984618ms step_avg:43.72ms -step:228600/500000 train_loss:1.9247 train_time:9993362ms step_avg:43.72ms -step:228800/500000 train_loss:1.8871 train_time:10002097ms step_avg:43.72ms -step:229000/500000 train_loss:2.0247 train_time:10010830ms step_avg:43.72ms -step:229200/500000 train_loss:1.9405 train_time:10019566ms step_avg:43.72ms -step:229400/500000 train_loss:2.0672 train_time:10028294ms step_avg:43.72ms -step:229600/500000 train_loss:2.0012 train_time:10037022ms step_avg:43.72ms -step:229800/500000 train_loss:2.0390 train_time:10045754ms step_avg:43.72ms -step:230000/500000 train_loss:1.9441 train_time:10054490ms step_avg:43.72ms -step:230200/500000 train_loss:1.9243 train_time:10063234ms step_avg:43.72ms -step:230400/500000 train_loss:2.0568 train_time:10071966ms step_avg:43.72ms -step:230600/500000 train_loss:1.7721 train_time:10080697ms step_avg:43.72ms -step:230800/500000 train_loss:1.9805 train_time:10089428ms step_avg:43.72ms -step:231000/500000 train_loss:1.8422 train_time:10098161ms step_avg:43.71ms -step:231200/500000 train_loss:2.0484 train_time:10106900ms step_avg:43.71ms -step:231400/500000 train_loss:2.0169 train_time:10115633ms step_avg:43.71ms -step:231600/500000 train_loss:2.0759 train_time:10124365ms step_avg:43.71ms -step:231800/500000 train_loss:1.9531 train_time:10133092ms step_avg:43.71ms -step:232000/500000 train_loss:2.0485 train_time:10141817ms step_avg:43.71ms -step:232200/500000 train_loss:1.9279 train_time:10150679ms step_avg:43.72ms -step:232400/500000 train_loss:2.1539 train_time:10159418ms step_avg:43.72ms -step:232600/500000 train_loss:2.0334 train_time:10168164ms step_avg:43.72ms -step:232800/500000 train_loss:1.7609 train_time:10176898ms step_avg:43.72ms -step:233000/500000 train_loss:2.0500 train_time:10185626ms step_avg:43.72ms -step:233200/500000 train_loss:1.9637 train_time:10194360ms step_avg:43.72ms -step:233400/500000 train_loss:2.0705 train_time:10203093ms step_avg:43.72ms -step:233600/500000 train_loss:2.0482 train_time:10211826ms step_avg:43.72ms -step:233800/500000 train_loss:2.1037 train_time:10220562ms step_avg:43.71ms -step:234000/500000 train_loss:1.9615 train_time:10229297ms step_avg:43.71ms -step:234200/500000 train_loss:1.9095 train_time:10238033ms step_avg:43.71ms -step:234400/500000 train_loss:1.9257 train_time:10246759ms step_avg:43.71ms -step:234600/500000 train_loss:2.2710 train_time:10255487ms step_avg:43.71ms -step:234800/500000 train_loss:1.9472 train_time:10264212ms step_avg:43.71ms -step:235000/500000 train_loss:2.0240 train_time:10272947ms step_avg:43.71ms -step:235200/500000 train_loss:2.0640 train_time:10281676ms step_avg:43.71ms -step:235400/500000 train_loss:1.8869 train_time:10290411ms step_avg:43.71ms -step:235600/500000 train_loss:1.9445 train_time:10299144ms step_avg:43.71ms -step:235800/500000 train_loss:2.0251 train_time:10307873ms step_avg:43.71ms -step:236000/500000 train_loss:1.9722 train_time:10316597ms step_avg:43.71ms -step:236200/500000 train_loss:2.0008 train_time:10325324ms step_avg:43.71ms -step:236400/500000 train_loss:2.1860 train_time:10334172ms step_avg:43.71ms -step:236600/500000 train_loss:1.9035 train_time:10342897ms step_avg:43.71ms -step:236800/500000 train_loss:1.9173 train_time:10351628ms step_avg:43.71ms -step:237000/500000 train_loss:2.0712 train_time:10360356ms step_avg:43.71ms -step:237200/500000 train_loss:1.9455 train_time:10369078ms step_avg:43.71ms -step:237400/500000 train_loss:2.0799 train_time:10377812ms step_avg:43.71ms -step:237600/500000 train_loss:1.9462 train_time:10386543ms step_avg:43.71ms -step:237800/500000 train_loss:2.0383 train_time:10395286ms step_avg:43.71ms -step:238000/500000 train_loss:1.9322 train_time:10404027ms step_avg:43.71ms -step:238200/500000 train_loss:1.9640 train_time:10412765ms step_avg:43.71ms -step:238400/500000 train_loss:2.0001 train_time:10421496ms step_avg:43.71ms -step:238600/500000 train_loss:2.1391 train_time:10430231ms step_avg:43.71ms -step:238800/500000 train_loss:1.9923 train_time:10438966ms step_avg:43.71ms -step:239000/500000 train_loss:2.0646 train_time:10447711ms step_avg:43.71ms -step:239200/500000 train_loss:1.8668 train_time:10456446ms step_avg:43.71ms -step:239400/500000 train_loss:2.4861 train_time:10465183ms step_avg:43.71ms -step:239600/500000 train_loss:1.8213 train_time:10473921ms step_avg:43.71ms -step:239800/500000 train_loss:2.0816 train_time:10482658ms step_avg:43.71ms -step:240000/500000 train_loss:1.8781 train_time:10491396ms step_avg:43.71ms -step:240000/500000 val_loss:2.0063 val_bpb:1.1883 train_time:10491412ms step_avg:43.71ms -step:240200/500000 train_loss:1.9510 train_time:10500136ms step_avg:43.71ms -step:240400/500000 train_loss:1.9606 train_time:10508991ms step_avg:43.71ms -step:240600/500000 train_loss:1.9815 train_time:10517721ms step_avg:43.71ms -step:240800/500000 train_loss:1.9837 train_time:10526454ms step_avg:43.71ms -step:241000/500000 train_loss:2.0349 train_time:10535193ms step_avg:43.71ms -step:241200/500000 train_loss:1.8801 train_time:10543928ms step_avg:43.71ms -step:241400/500000 train_loss:1.9354 train_time:10552661ms step_avg:43.71ms -step:241600/500000 train_loss:1.8986 train_time:10561395ms step_avg:43.71ms -step:241800/500000 train_loss:2.1174 train_time:10570128ms step_avg:43.71ms -step:242000/500000 train_loss:1.9755 train_time:10578858ms step_avg:43.71ms -step:242200/500000 train_loss:1.9471 train_time:10587594ms step_avg:43.71ms -step:242400/500000 train_loss:2.0002 train_time:10596330ms step_avg:43.71ms -step:242600/500000 train_loss:1.9670 train_time:10605064ms step_avg:43.71ms -step:242800/500000 train_loss:1.8907 train_time:10613816ms step_avg:43.71ms -step:243000/500000 train_loss:2.0368 train_time:10622560ms step_avg:43.71ms -step:243200/500000 train_loss:2.0311 train_time:10631301ms step_avg:43.71ms -step:243400/500000 train_loss:1.8762 train_time:10640036ms step_avg:43.71ms -step:243600/500000 train_loss:1.9528 train_time:10648777ms step_avg:43.71ms -step:243800/500000 train_loss:2.0887 train_time:10657512ms step_avg:43.71ms -step:244000/500000 train_loss:2.0231 train_time:10666246ms step_avg:43.71ms -step:244200/500000 train_loss:2.0592 train_time:10674982ms step_avg:43.71ms -step:244400/500000 train_loss:2.0455 train_time:10683715ms step_avg:43.71ms -step:244600/500000 train_loss:1.9829 train_time:10692567ms step_avg:43.71ms -step:244800/500000 train_loss:1.8873 train_time:10701300ms step_avg:43.71ms -step:245000/500000 train_loss:1.9246 train_time:10710035ms step_avg:43.71ms -step:245200/500000 train_loss:1.9532 train_time:10718767ms step_avg:43.71ms -step:245400/500000 train_loss:2.0863 train_time:10727502ms step_avg:43.71ms -step:245600/500000 train_loss:2.0128 train_time:10736234ms step_avg:43.71ms -step:245800/500000 train_loss:2.0420 train_time:10744968ms step_avg:43.71ms -step:246000/500000 train_loss:1.9984 train_time:10753707ms step_avg:43.71ms -step:246200/500000 train_loss:1.9309 train_time:10762446ms step_avg:43.71ms -step:246400/500000 train_loss:2.0014 train_time:10771184ms step_avg:43.71ms -step:246600/500000 train_loss:1.9898 train_time:10779923ms step_avg:43.71ms -step:246800/500000 train_loss:1.9892 train_time:10788658ms step_avg:43.71ms -step:247000/500000 train_loss:2.0507 train_time:10797397ms step_avg:43.71ms -step:247200/500000 train_loss:1.8842 train_time:10806137ms step_avg:43.71ms -step:247400/500000 train_loss:1.9936 train_time:10814874ms step_avg:43.71ms -step:247600/500000 train_loss:2.0653 train_time:10823607ms step_avg:43.71ms -step:247800/500000 train_loss:2.0729 train_time:10832353ms step_avg:43.71ms -step:248000/500000 train_loss:1.9083 train_time:10841094ms step_avg:43.71ms -step:248200/500000 train_loss:1.9371 train_time:10849830ms step_avg:43.71ms -step:248400/500000 train_loss:1.9112 train_time:10858566ms step_avg:43.71ms -step:248600/500000 train_loss:2.0996 train_time:10867428ms step_avg:43.71ms -step:248800/500000 train_loss:1.9620 train_time:10876162ms step_avg:43.71ms -step:249000/500000 train_loss:2.0224 train_time:10884895ms step_avg:43.71ms -step:249200/500000 train_loss:1.9578 train_time:10893631ms step_avg:43.71ms -step:249400/500000 train_loss:1.8937 train_time:10902369ms step_avg:43.71ms -step:249600/500000 train_loss:1.8381 train_time:10911107ms step_avg:43.71ms -step:249800/500000 train_loss:2.0133 train_time:10919845ms step_avg:43.71ms -step:250000/500000 train_loss:2.0581 train_time:10928578ms step_avg:43.71ms -step:250200/500000 train_loss:1.9867 train_time:10937312ms step_avg:43.71ms -step:250400/500000 train_loss:1.9796 train_time:10946038ms step_avg:43.71ms -step:250600/500000 train_loss:2.0323 train_time:10954770ms step_avg:43.71ms -step:250800/500000 train_loss:1.8523 train_time:10963508ms step_avg:43.71ms -step:251000/500000 train_loss:1.9193 train_time:10972245ms step_avg:43.71ms -step:251200/500000 train_loss:1.9896 train_time:10980982ms step_avg:43.71ms -step:251400/500000 train_loss:1.9091 train_time:10989715ms step_avg:43.71ms -step:251600/500000 train_loss:1.9960 train_time:10998454ms step_avg:43.71ms -step:251800/500000 train_loss:2.0287 train_time:11007188ms step_avg:43.71ms -step:252000/500000 train_loss:1.9612 train_time:11015923ms step_avg:43.71ms -step:252200/500000 train_loss:2.0508 train_time:11024659ms step_avg:43.71ms -step:252400/500000 train_loss:1.9624 train_time:11033389ms step_avg:43.71ms -step:252600/500000 train_loss:1.9766 train_time:11042125ms step_avg:43.71ms -step:252800/500000 train_loss:2.0657 train_time:11050997ms step_avg:43.71ms -step:253000/500000 train_loss:1.9535 train_time:11059732ms step_avg:43.71ms -step:253200/500000 train_loss:2.0059 train_time:11068474ms step_avg:43.71ms -step:253400/500000 train_loss:1.9705 train_time:11077211ms step_avg:43.71ms -step:253600/500000 train_loss:2.0051 train_time:11085946ms step_avg:43.71ms -step:253800/500000 train_loss:1.8818 train_time:11094681ms step_avg:43.71ms -step:254000/500000 train_loss:2.0705 train_time:11103415ms step_avg:43.71ms -step:254200/500000 train_loss:1.9941 train_time:11112156ms step_avg:43.71ms -step:254400/500000 train_loss:1.9379 train_time:11120890ms step_avg:43.71ms -step:254600/500000 train_loss:1.9741 train_time:11129625ms step_avg:43.71ms -step:254800/500000 train_loss:1.8901 train_time:11138356ms step_avg:43.71ms -step:255000/500000 train_loss:1.9534 train_time:11147089ms step_avg:43.71ms -step:255200/500000 train_loss:2.0204 train_time:11155825ms step_avg:43.71ms -step:255400/500000 train_loss:1.8525 train_time:11164563ms step_avg:43.71ms -step:255600/500000 train_loss:2.0028 train_time:11173305ms step_avg:43.71ms -step:255800/500000 train_loss:1.9450 train_time:11182037ms step_avg:43.71ms -step:256000/500000 train_loss:2.0375 train_time:11190773ms step_avg:43.71ms -step:256200/500000 train_loss:1.9353 train_time:11199508ms step_avg:43.71ms -step:256400/500000 train_loss:2.1361 train_time:11208247ms step_avg:43.71ms -step:256600/500000 train_loss:2.0129 train_time:11216977ms step_avg:43.71ms -step:256800/500000 train_loss:2.0725 train_time:11225838ms step_avg:43.71ms -step:257000/500000 train_loss:2.1332 train_time:11234569ms step_avg:43.71ms -step:257200/500000 train_loss:1.7644 train_time:11243314ms step_avg:43.71ms -step:257400/500000 train_loss:2.0439 train_time:11252048ms step_avg:43.71ms -step:257600/500000 train_loss:1.9964 train_time:11260788ms step_avg:43.71ms -step:257800/500000 train_loss:2.0681 train_time:11269525ms step_avg:43.71ms -step:258000/500000 train_loss:1.9432 train_time:11278255ms step_avg:43.71ms -step:258200/500000 train_loss:1.9413 train_time:11286980ms step_avg:43.71ms -step:258400/500000 train_loss:2.1118 train_time:11295708ms step_avg:43.71ms -step:258600/500000 train_loss:1.9258 train_time:11304449ms step_avg:43.71ms -step:258800/500000 train_loss:2.0245 train_time:11313180ms step_avg:43.71ms -step:259000/500000 train_loss:2.0141 train_time:11321910ms step_avg:43.71ms -step:259200/500000 train_loss:1.9900 train_time:11330641ms step_avg:43.71ms -step:259400/500000 train_loss:2.4748 train_time:11339371ms step_avg:43.71ms -step:259600/500000 train_loss:2.0246 train_time:11348100ms step_avg:43.71ms -step:259800/500000 train_loss:1.9782 train_time:11356835ms step_avg:43.71ms -step:260000/500000 train_loss:1.9953 train_time:11365662ms step_avg:43.71ms -step:260000/500000 val_loss:2.0041 val_bpb:1.1870 train_time:11365677ms step_avg:43.71ms -step:260200/500000 train_loss:1.9975 train_time:11374405ms step_avg:43.71ms -step:260400/500000 train_loss:2.0585 train_time:11383133ms step_avg:43.71ms -step:260600/500000 train_loss:2.0042 train_time:11391866ms step_avg:43.71ms -step:260800/500000 train_loss:1.9907 train_time:11400605ms step_avg:43.71ms -step:261000/500000 train_loss:1.8036 train_time:11409345ms step_avg:43.71ms -step:261200/500000 train_loss:1.9688 train_time:11418070ms step_avg:43.71ms -step:261400/500000 train_loss:2.0001 train_time:11426801ms step_avg:43.71ms -step:261600/500000 train_loss:1.7854 train_time:11435534ms step_avg:43.71ms -step:261800/500000 train_loss:1.9908 train_time:11444271ms step_avg:43.71ms -step:262000/500000 train_loss:2.0764 train_time:11453013ms step_avg:43.71ms -step:262200/500000 train_loss:1.8727 train_time:11461753ms step_avg:43.71ms -step:262400/500000 train_loss:2.6147 train_time:11470498ms step_avg:43.71ms -step:262600/500000 train_loss:2.0029 train_time:11479235ms step_avg:43.71ms -step:262800/500000 train_loss:2.2071 train_time:11487964ms step_avg:43.71ms -step:263000/500000 train_loss:2.0746 train_time:11496701ms step_avg:43.71ms -step:263200/500000 train_loss:2.1073 train_time:11505436ms step_avg:43.71ms -step:263400/500000 train_loss:1.9842 train_time:11514169ms step_avg:43.71ms -step:263600/500000 train_loss:2.0114 train_time:11522905ms step_avg:43.71ms -step:263800/500000 train_loss:2.0245 train_time:11531645ms step_avg:43.71ms -step:264000/500000 train_loss:2.0311 train_time:11540511ms step_avg:43.71ms -step:264200/500000 train_loss:1.9754 train_time:11549243ms step_avg:43.71ms -step:264400/500000 train_loss:2.0189 train_time:11557976ms step_avg:43.71ms -step:264600/500000 train_loss:2.0092 train_time:11566712ms step_avg:43.71ms -step:264800/500000 train_loss:1.9114 train_time:11575445ms step_avg:43.71ms -step:265000/500000 train_loss:1.8924 train_time:11584189ms step_avg:43.71ms -step:265200/500000 train_loss:2.0347 train_time:11592928ms step_avg:43.71ms -step:265400/500000 train_loss:2.0220 train_time:11601670ms step_avg:43.71ms -step:265600/500000 train_loss:2.0149 train_time:11610396ms step_avg:43.71ms -step:265800/500000 train_loss:2.0498 train_time:11619137ms step_avg:43.71ms -step:266000/500000 train_loss:2.0556 train_time:11627871ms step_avg:43.71ms -step:266200/500000 train_loss:1.9562 train_time:11636604ms step_avg:43.71ms -step:266400/500000 train_loss:2.0086 train_time:11645338ms step_avg:43.71ms -step:266600/500000 train_loss:2.0249 train_time:11654089ms step_avg:43.71ms -step:266800/500000 train_loss:1.9725 train_time:11662827ms step_avg:43.71ms -step:267000/500000 train_loss:2.0500 train_time:11671566ms step_avg:43.71ms -step:267200/500000 train_loss:2.0668 train_time:11680299ms step_avg:43.71ms -step:267400/500000 train_loss:1.9282 train_time:11689039ms step_avg:43.71ms -step:267600/500000 train_loss:1.9192 train_time:11697772ms step_avg:43.71ms -step:267800/500000 train_loss:1.9734 train_time:11706515ms step_avg:43.71ms -step:268000/500000 train_loss:2.0209 train_time:11715254ms step_avg:43.71ms -step:268200/500000 train_loss:1.9353 train_time:11724109ms step_avg:43.71ms -step:268400/500000 train_loss:2.0996 train_time:11732847ms step_avg:43.71ms -step:268600/500000 train_loss:1.9886 train_time:11741584ms step_avg:43.71ms -step:268800/500000 train_loss:2.0374 train_time:11750317ms step_avg:43.71ms -step:269000/500000 train_loss:2.0219 train_time:11759057ms step_avg:43.71ms -step:269200/500000 train_loss:1.8802 train_time:11767793ms step_avg:43.71ms -step:269400/500000 train_loss:2.0656 train_time:11776526ms step_avg:43.71ms -step:269600/500000 train_loss:1.9443 train_time:11785264ms step_avg:43.71ms -step:269800/500000 train_loss:1.9677 train_time:11793999ms step_avg:43.71ms -step:270000/500000 train_loss:2.0042 train_time:11802735ms step_avg:43.71ms -step:270200/500000 train_loss:1.9749 train_time:11811467ms step_avg:43.71ms -step:270400/500000 train_loss:2.0969 train_time:11820205ms step_avg:43.71ms -step:270600/500000 train_loss:2.1183 train_time:11828939ms step_avg:43.71ms -step:270800/500000 train_loss:1.8985 train_time:11837662ms step_avg:43.71ms -step:271000/500000 train_loss:2.0729 train_time:11846398ms step_avg:43.71ms -step:271200/500000 train_loss:2.0153 train_time:11855135ms step_avg:43.71ms -step:271400/500000 train_loss:2.0576 train_time:11863864ms step_avg:43.71ms -step:271600/500000 train_loss:1.9227 train_time:11872593ms step_avg:43.71ms -step:271800/500000 train_loss:1.9194 train_time:11881327ms step_avg:43.71ms -step:272000/500000 train_loss:2.0508 train_time:11890061ms step_avg:43.71ms -step:272200/500000 train_loss:1.9958 train_time:11898917ms step_avg:43.71ms -step:272400/500000 train_loss:2.0412 train_time:11907652ms step_avg:43.71ms -step:272600/500000 train_loss:2.1157 train_time:11916387ms step_avg:43.71ms -step:272800/500000 train_loss:1.9742 train_time:11925122ms step_avg:43.71ms -step:273000/500000 train_loss:1.9292 train_time:11933857ms step_avg:43.71ms -step:273200/500000 train_loss:2.2138 train_time:11942593ms step_avg:43.71ms -step:273400/500000 train_loss:2.0001 train_time:11951335ms step_avg:43.71ms -step:273600/500000 train_loss:2.0742 train_time:11960070ms step_avg:43.71ms -step:273800/500000 train_loss:2.0358 train_time:11968805ms step_avg:43.71ms -step:274000/500000 train_loss:1.9466 train_time:11977538ms step_avg:43.71ms -step:274200/500000 train_loss:2.1008 train_time:11986276ms step_avg:43.71ms -step:274400/500000 train_loss:1.9161 train_time:11995003ms step_avg:43.71ms -step:274600/500000 train_loss:1.9558 train_time:12003738ms step_avg:43.71ms -step:274800/500000 train_loss:2.2432 train_time:12012473ms step_avg:43.71ms -step:275000/500000 train_loss:2.0272 train_time:12021207ms step_avg:43.71ms -step:275200/500000 train_loss:2.0214 train_time:12029946ms step_avg:43.71ms -step:275400/500000 train_loss:1.9085 train_time:12038674ms step_avg:43.71ms -step:275600/500000 train_loss:2.0615 train_time:12047412ms step_avg:43.71ms -step:275800/500000 train_loss:2.1550 train_time:12056142ms step_avg:43.71ms -step:276000/500000 train_loss:2.1168 train_time:12064879ms step_avg:43.71ms -step:276200/500000 train_loss:1.9004 train_time:12073619ms step_avg:43.71ms -step:276400/500000 train_loss:1.8931 train_time:12082474ms step_avg:43.71ms -step:276600/500000 train_loss:2.0315 train_time:12091206ms step_avg:43.71ms -step:276800/500000 train_loss:2.0176 train_time:12099933ms step_avg:43.71ms -step:277000/500000 train_loss:2.0328 train_time:12108668ms step_avg:43.71ms -step:277200/500000 train_loss:1.9919 train_time:12117401ms step_avg:43.71ms -step:277400/500000 train_loss:2.0349 train_time:12126135ms step_avg:43.71ms -step:277600/500000 train_loss:1.9732 train_time:12134862ms step_avg:43.71ms -step:277800/500000 train_loss:2.0866 train_time:12143592ms step_avg:43.71ms -step:278000/500000 train_loss:1.9639 train_time:12152324ms step_avg:43.71ms -step:278200/500000 train_loss:1.8576 train_time:12161051ms step_avg:43.71ms -step:278400/500000 train_loss:1.9091 train_time:12169780ms step_avg:43.71ms -step:278600/500000 train_loss:1.9311 train_time:12178513ms step_avg:43.71ms -step:278800/500000 train_loss:2.0320 train_time:12187246ms step_avg:43.71ms -step:279000/500000 train_loss:2.2007 train_time:12195978ms step_avg:43.71ms -step:279200/500000 train_loss:1.9399 train_time:12204704ms step_avg:43.71ms -step:279400/500000 train_loss:2.1437 train_time:12213438ms step_avg:43.71ms -step:279600/500000 train_loss:1.8597 train_time:12222170ms step_avg:43.71ms -step:279800/500000 train_loss:1.8687 train_time:12230902ms step_avg:43.71ms -step:280000/500000 train_loss:1.9473 train_time:12239631ms step_avg:43.71ms -step:280000/500000 val_loss:2.0059 val_bpb:1.1880 train_time:12239647ms step_avg:43.71ms -step:280200/500000 train_loss:1.8587 train_time:12248371ms step_avg:43.71ms -step:280400/500000 train_loss:1.9275 train_time:12257226ms step_avg:43.71ms -step:280600/500000 train_loss:2.0865 train_time:12265963ms step_avg:43.71ms -step:280800/500000 train_loss:2.0359 train_time:12274697ms step_avg:43.71ms -step:281000/500000 train_loss:2.0256 train_time:12283428ms step_avg:43.71ms -step:281200/500000 train_loss:2.0721 train_time:12292168ms step_avg:43.71ms -step:281400/500000 train_loss:1.9527 train_time:12300901ms step_avg:43.71ms -step:281600/500000 train_loss:2.0435 train_time:12309642ms step_avg:43.71ms -step:281800/500000 train_loss:1.9409 train_time:12318374ms step_avg:43.71ms -step:282000/500000 train_loss:1.8839 train_time:12327110ms step_avg:43.71ms -step:282200/500000 train_loss:1.9873 train_time:12335849ms step_avg:43.71ms -step:282400/500000 train_loss:2.0421 train_time:12344583ms step_avg:43.71ms -step:282600/500000 train_loss:2.1513 train_time:12353319ms step_avg:43.71ms -step:282800/500000 train_loss:1.9599 train_time:12362052ms step_avg:43.71ms -step:283000/500000 train_loss:2.1058 train_time:12370792ms step_avg:43.71ms -step:283200/500000 train_loss:2.0484 train_time:12379527ms step_avg:43.71ms -step:283400/500000 train_loss:2.0516 train_time:12388269ms step_avg:43.71ms -step:283600/500000 train_loss:1.9050 train_time:12397005ms step_avg:43.71ms -step:283800/500000 train_loss:1.9629 train_time:12405746ms step_avg:43.71ms -step:284000/500000 train_loss:1.9291 train_time:12414495ms step_avg:43.71ms -step:284200/500000 train_loss:2.0472 train_time:12423237ms step_avg:43.71ms -step:284400/500000 train_loss:2.0072 train_time:12431976ms step_avg:43.71ms -step:284600/500000 train_loss:1.9916 train_time:12440841ms step_avg:43.71ms -step:284800/500000 train_loss:1.9631 train_time:12449579ms step_avg:43.71ms -step:285000/500000 train_loss:2.0340 train_time:12458321ms step_avg:43.71ms -step:285200/500000 train_loss:1.9090 train_time:12467053ms step_avg:43.71ms -step:285400/500000 train_loss:2.1183 train_time:12475792ms step_avg:43.71ms -step:285600/500000 train_loss:2.0143 train_time:12484524ms step_avg:43.71ms -step:285800/500000 train_loss:2.0114 train_time:12493265ms step_avg:43.71ms -step:286000/500000 train_loss:2.1389 train_time:12502004ms step_avg:43.71ms -step:286200/500000 train_loss:1.9397 train_time:12510739ms step_avg:43.71ms -step:286400/500000 train_loss:1.9368 train_time:12519480ms step_avg:43.71ms -step:286600/500000 train_loss:2.0340 train_time:12528218ms step_avg:43.71ms -step:286800/500000 train_loss:2.0173 train_time:12536952ms step_avg:43.71ms -step:287000/500000 train_loss:2.0685 train_time:12545692ms step_avg:43.71ms -step:287200/500000 train_loss:2.0053 train_time:12554435ms step_avg:43.71ms -step:287400/500000 train_loss:1.9727 train_time:12563172ms step_avg:43.71ms -step:287600/500000 train_loss:1.9363 train_time:12571910ms step_avg:43.71ms -step:287800/500000 train_loss:2.4742 train_time:12580644ms step_avg:43.71ms -step:288000/500000 train_loss:2.0207 train_time:12589381ms step_avg:43.71ms -step:288200/500000 train_loss:1.9881 train_time:12598121ms step_avg:43.71ms -step:288400/500000 train_loss:2.0589 train_time:12606861ms step_avg:43.71ms -step:288600/500000 train_loss:1.9856 train_time:12615613ms step_avg:43.71ms -step:288800/500000 train_loss:1.9577 train_time:12624470ms step_avg:43.71ms -step:289000/500000 train_loss:2.0183 train_time:12633200ms step_avg:43.71ms -step:289200/500000 train_loss:1.8817 train_time:12641936ms step_avg:43.71ms -step:289400/500000 train_loss:2.1841 train_time:12650673ms step_avg:43.71ms -step:289600/500000 train_loss:2.0147 train_time:12659407ms step_avg:43.71ms -step:289800/500000 train_loss:2.0655 train_time:12668147ms step_avg:43.71ms -step:290000/500000 train_loss:2.1097 train_time:12676876ms step_avg:43.71ms -step:290200/500000 train_loss:1.9592 train_time:12685620ms step_avg:43.71ms -step:290400/500000 train_loss:2.0262 train_time:12694358ms step_avg:43.71ms -step:290600/500000 train_loss:1.9529 train_time:12703095ms step_avg:43.71ms -step:290800/500000 train_loss:2.0480 train_time:12711840ms step_avg:43.71ms -step:291000/500000 train_loss:1.8834 train_time:12720580ms step_avg:43.71ms -step:291200/500000 train_loss:2.3031 train_time:12729321ms step_avg:43.71ms -step:291400/500000 train_loss:1.9826 train_time:12738057ms step_avg:43.71ms -step:291600/500000 train_loss:2.1083 train_time:12746796ms step_avg:43.71ms -step:291800/500000 train_loss:2.0582 train_time:12755531ms step_avg:43.71ms -step:292000/500000 train_loss:1.9632 train_time:12764274ms step_avg:43.71ms -step:292200/500000 train_loss:2.3024 train_time:12773015ms step_avg:43.71ms -step:292400/500000 train_loss:2.0066 train_time:12781752ms step_avg:43.71ms -step:292600/500000 train_loss:1.9131 train_time:12790492ms step_avg:43.71ms -step:292800/500000 train_loss:1.7997 train_time:12799355ms step_avg:43.71ms -step:293000/500000 train_loss:1.9488 train_time:12808094ms step_avg:43.71ms -step:293200/500000 train_loss:1.9176 train_time:12816830ms step_avg:43.71ms -step:293400/500000 train_loss:1.9941 train_time:12825563ms step_avg:43.71ms -step:293600/500000 train_loss:1.9393 train_time:12834298ms step_avg:43.71ms -step:293800/500000 train_loss:2.0138 train_time:12843033ms step_avg:43.71ms -step:294000/500000 train_loss:1.9956 train_time:12851779ms step_avg:43.71ms -step:294200/500000 train_loss:2.2086 train_time:12860523ms step_avg:43.71ms -step:294400/500000 train_loss:2.0443 train_time:12869263ms step_avg:43.71ms -step:294600/500000 train_loss:2.0434 train_time:12878001ms step_avg:43.71ms -step:294800/500000 train_loss:1.9583 train_time:12886742ms step_avg:43.71ms -step:295000/500000 train_loss:2.0998 train_time:12895484ms step_avg:43.71ms -step:295200/500000 train_loss:2.0722 train_time:12904220ms step_avg:43.71ms -step:295400/500000 train_loss:1.9461 train_time:12912961ms step_avg:43.71ms -step:295600/500000 train_loss:2.1570 train_time:12921702ms step_avg:43.71ms -step:295800/500000 train_loss:2.0004 train_time:12930439ms step_avg:43.71ms -step:296000/500000 train_loss:2.0067 train_time:12939179ms step_avg:43.71ms -step:296200/500000 train_loss:2.0775 train_time:12947917ms step_avg:43.71ms -step:296400/500000 train_loss:1.9923 train_time:12956652ms step_avg:43.71ms -step:296600/500000 train_loss:1.9149 train_time:12965382ms step_avg:43.71ms -step:296800/500000 train_loss:2.0540 train_time:12974118ms step_avg:43.71ms -step:297000/500000 train_loss:2.0463 train_time:12982942ms step_avg:43.71ms -step:297200/500000 train_loss:2.0057 train_time:12991673ms step_avg:43.71ms -step:297400/500000 train_loss:2.0463 train_time:13000411ms step_avg:43.71ms -step:297600/500000 train_loss:1.9426 train_time:13009145ms step_avg:43.71ms -step:297800/500000 train_loss:2.1075 train_time:13017877ms step_avg:43.71ms -step:298000/500000 train_loss:2.0239 train_time:13026611ms step_avg:43.71ms -step:298200/500000 train_loss:2.0610 train_time:13035344ms step_avg:43.71ms -step:298400/500000 train_loss:1.9703 train_time:13044080ms step_avg:43.71ms -step:298600/500000 train_loss:2.0391 train_time:13052807ms step_avg:43.71ms -step:298800/500000 train_loss:2.0636 train_time:13061543ms step_avg:43.71ms -step:299000/500000 train_loss:2.0894 train_time:13070275ms step_avg:43.71ms -step:299200/500000 train_loss:2.0629 train_time:13079003ms step_avg:43.71ms -step:299400/500000 train_loss:1.9541 train_time:13087740ms step_avg:43.71ms -step:299600/500000 train_loss:1.9493 train_time:13096471ms step_avg:43.71ms -step:299800/500000 train_loss:1.9702 train_time:13105202ms step_avg:43.71ms -step:300000/500000 train_loss:2.1034 train_time:13114057ms step_avg:43.71ms -step:300000/500000 val_loss:2.0019 val_bpb:1.1856 train_time:13114073ms step_avg:43.71ms -step:300200/500000 train_loss:2.0228 train_time:13122789ms step_avg:43.71ms -step:300400/500000 train_loss:1.9918 train_time:13131517ms step_avg:43.71ms -step:300600/500000 train_loss:1.9048 train_time:13140246ms step_avg:43.71ms -step:300800/500000 train_loss:1.9942 train_time:13148989ms step_avg:43.71ms -step:301000/500000 train_loss:2.1526 train_time:13157717ms step_avg:43.71ms -step:301200/500000 train_loss:1.9766 train_time:13166455ms step_avg:43.71ms -step:301400/500000 train_loss:2.0515 train_time:13175188ms step_avg:43.71ms -step:301600/500000 train_loss:1.8824 train_time:13183920ms step_avg:43.71ms -step:301800/500000 train_loss:1.8421 train_time:13192653ms step_avg:43.71ms -step:302000/500000 train_loss:2.0031 train_time:13201380ms step_avg:43.71ms -step:302200/500000 train_loss:1.9957 train_time:13210119ms step_avg:43.71ms -step:302400/500000 train_loss:2.1431 train_time:13218856ms step_avg:43.71ms -step:302600/500000 train_loss:2.0113 train_time:13227593ms step_avg:43.71ms -step:302800/500000 train_loss:2.0139 train_time:13236331ms step_avg:43.71ms -step:303000/500000 train_loss:1.9975 train_time:13245066ms step_avg:43.71ms -step:303200/500000 train_loss:2.1686 train_time:13253809ms step_avg:43.71ms -step:303400/500000 train_loss:1.9958 train_time:13262553ms step_avg:43.71ms -step:303600/500000 train_loss:2.0634 train_time:13271294ms step_avg:43.71ms -step:303800/500000 train_loss:1.8341 train_time:13280026ms step_avg:43.71ms -step:304000/500000 train_loss:2.0196 train_time:13288891ms step_avg:43.71ms -step:304200/500000 train_loss:2.0038 train_time:13297621ms step_avg:43.71ms -step:304400/500000 train_loss:2.1533 train_time:13306354ms step_avg:43.71ms -step:304600/500000 train_loss:1.9385 train_time:13315087ms step_avg:43.71ms -step:304800/500000 train_loss:1.9258 train_time:13323823ms step_avg:43.71ms -step:305000/500000 train_loss:2.0313 train_time:13332561ms step_avg:43.71ms -step:305200/500000 train_loss:2.0963 train_time:13341296ms step_avg:43.71ms -step:305400/500000 train_loss:1.9291 train_time:13350029ms step_avg:43.71ms -step:305600/500000 train_loss:2.1894 train_time:13358764ms step_avg:43.71ms -step:305800/500000 train_loss:1.9851 train_time:13367498ms step_avg:43.71ms -step:306000/500000 train_loss:2.0476 train_time:13376229ms step_avg:43.71ms -step:306200/500000 train_loss:1.9089 train_time:13384965ms step_avg:43.71ms -step:306400/500000 train_loss:1.9578 train_time:13393701ms step_avg:43.71ms -step:306600/500000 train_loss:1.9322 train_time:13402431ms step_avg:43.71ms -step:306800/500000 train_loss:1.9681 train_time:13411166ms step_avg:43.71ms -step:307000/500000 train_loss:1.9248 train_time:13419899ms step_avg:43.71ms -step:307200/500000 train_loss:2.0327 train_time:13428633ms step_avg:43.71ms -step:307400/500000 train_loss:2.1124 train_time:13437369ms step_avg:43.71ms -step:307600/500000 train_loss:1.8934 train_time:13446105ms step_avg:43.71ms -step:307800/500000 train_loss:1.9565 train_time:13454841ms step_avg:43.71ms -step:308000/500000 train_loss:1.9663 train_time:13463575ms step_avg:43.71ms -step:308200/500000 train_loss:2.0848 train_time:13472431ms step_avg:43.71ms -step:308400/500000 train_loss:1.9285 train_time:13481167ms step_avg:43.71ms -step:308600/500000 train_loss:2.0670 train_time:13489905ms step_avg:43.71ms -step:308800/500000 train_loss:1.8746 train_time:13498638ms step_avg:43.71ms -step:309000/500000 train_loss:1.9112 train_time:13507371ms step_avg:43.71ms -step:309200/500000 train_loss:1.9494 train_time:13516103ms step_avg:43.71ms -step:309400/500000 train_loss:1.9585 train_time:13524834ms step_avg:43.71ms -step:309600/500000 train_loss:2.0978 train_time:13533563ms step_avg:43.71ms -step:309800/500000 train_loss:1.7424 train_time:13542293ms step_avg:43.71ms -step:310000/500000 train_loss:2.0507 train_time:13551024ms step_avg:43.71ms -step:310200/500000 train_loss:1.9832 train_time:13559759ms step_avg:43.71ms -step:310400/500000 train_loss:2.0606 train_time:13568497ms step_avg:43.71ms -step:310600/500000 train_loss:1.9982 train_time:13577225ms step_avg:43.71ms -step:310800/500000 train_loss:1.9846 train_time:13585959ms step_avg:43.71ms -step:311000/500000 train_loss:2.1402 train_time:13594684ms step_avg:43.71ms -step:311200/500000 train_loss:1.9878 train_time:13603418ms step_avg:43.71ms -step:311400/500000 train_loss:2.0720 train_time:13612147ms step_avg:43.71ms -step:311600/500000 train_loss:1.9734 train_time:13620878ms step_avg:43.71ms -step:311800/500000 train_loss:2.0040 train_time:13629609ms step_avg:43.71ms -step:312000/500000 train_loss:1.9738 train_time:13638336ms step_avg:43.71ms -step:312200/500000 train_loss:2.0241 train_time:13647067ms step_avg:43.71ms -step:312400/500000 train_loss:1.9978 train_time:13655930ms step_avg:43.71ms -step:312600/500000 train_loss:2.0128 train_time:13664662ms step_avg:43.71ms -step:312800/500000 train_loss:2.0693 train_time:13673392ms step_avg:43.71ms -step:313000/500000 train_loss:2.0319 train_time:13682118ms step_avg:43.71ms -step:313200/500000 train_loss:1.9284 train_time:13690852ms step_avg:43.71ms -step:313400/500000 train_loss:1.8951 train_time:13699586ms step_avg:43.71ms -step:313600/500000 train_loss:2.0394 train_time:13708322ms step_avg:43.71ms -step:313800/500000 train_loss:2.0955 train_time:13717054ms step_avg:43.71ms -step:314000/500000 train_loss:1.9987 train_time:13725790ms step_avg:43.71ms -step:314200/500000 train_loss:2.0925 train_time:13734528ms step_avg:43.71ms -step:314400/500000 train_loss:1.9578 train_time:13743269ms step_avg:43.71ms -step:314600/500000 train_loss:2.0680 train_time:13752011ms step_avg:43.71ms -step:314800/500000 train_loss:2.0073 train_time:13760745ms step_avg:43.71ms -step:315000/500000 train_loss:2.0866 train_time:13769479ms step_avg:43.71ms -step:315200/500000 train_loss:1.9914 train_time:13778210ms step_avg:43.71ms -step:315400/500000 train_loss:1.9911 train_time:13786945ms step_avg:43.71ms -step:315600/500000 train_loss:2.1094 train_time:13795674ms step_avg:43.71ms -step:315800/500000 train_loss:1.9628 train_time:13804407ms step_avg:43.71ms -step:316000/500000 train_loss:1.9922 train_time:13813140ms step_avg:43.71ms -step:316200/500000 train_loss:1.9326 train_time:13821872ms step_avg:43.71ms -step:316400/500000 train_loss:1.9757 train_time:13830721ms step_avg:43.71ms -step:316600/500000 train_loss:1.9799 train_time:13839462ms step_avg:43.71ms -step:316800/500000 train_loss:2.0344 train_time:13848182ms step_avg:43.71ms -step:317000/500000 train_loss:1.9756 train_time:13856912ms step_avg:43.71ms -step:317200/500000 train_loss:1.8813 train_time:13865639ms step_avg:43.71ms -step:317400/500000 train_loss:2.0688 train_time:13874365ms step_avg:43.71ms -step:317600/500000 train_loss:2.1282 train_time:13883090ms step_avg:43.71ms -step:317800/500000 train_loss:2.0416 train_time:13891818ms step_avg:43.71ms -step:318000/500000 train_loss:1.9459 train_time:13900547ms step_avg:43.71ms -step:318200/500000 train_loss:2.1129 train_time:13909275ms step_avg:43.71ms -step:318400/500000 train_loss:1.9522 train_time:13918006ms step_avg:43.71ms -step:318600/500000 train_loss:2.0662 train_time:13926738ms step_avg:43.71ms -step:318800/500000 train_loss:2.0459 train_time:13935471ms step_avg:43.71ms -step:319000/500000 train_loss:2.0817 train_time:13944202ms step_avg:43.71ms -step:319200/500000 train_loss:2.0718 train_time:13952933ms step_avg:43.71ms -step:319400/500000 train_loss:2.0248 train_time:13961668ms step_avg:43.71ms -step:319600/500000 train_loss:2.0153 train_time:13970395ms step_avg:43.71ms -step:319800/500000 train_loss:1.9507 train_time:13979128ms step_avg:43.71ms -step:320000/500000 train_loss:1.9200 train_time:13987858ms step_avg:43.71ms -step:320000/500000 val_loss:2.0010 val_bpb:1.1851 train_time:13987873ms step_avg:43.71ms -step:320200/500000 train_loss:2.0915 train_time:13996588ms step_avg:43.71ms -step:320400/500000 train_loss:1.9669 train_time:14005321ms step_avg:43.71ms -step:320600/500000 train_loss:2.0496 train_time:14014180ms step_avg:43.71ms -step:320800/500000 train_loss:2.1033 train_time:14022909ms step_avg:43.71ms -step:321000/500000 train_loss:1.8803 train_time:14031642ms step_avg:43.71ms -step:321200/500000 train_loss:1.9655 train_time:14040378ms step_avg:43.71ms -step:321400/500000 train_loss:2.0748 train_time:14049121ms step_avg:43.71ms -step:321600/500000 train_loss:2.0919 train_time:14057851ms step_avg:43.71ms -step:321800/500000 train_loss:2.0200 train_time:14066585ms step_avg:43.71ms -step:322000/500000 train_loss:2.0471 train_time:14075312ms step_avg:43.71ms -step:322200/500000 train_loss:1.9971 train_time:14084041ms step_avg:43.71ms -step:322400/500000 train_loss:1.8674 train_time:14092769ms step_avg:43.71ms -step:322600/500000 train_loss:1.9619 train_time:14101496ms step_avg:43.71ms -step:322800/500000 train_loss:2.0699 train_time:14110228ms step_avg:43.71ms -step:323000/500000 train_loss:2.4914 train_time:14118957ms step_avg:43.71ms -step:323200/500000 train_loss:1.9907 train_time:14127689ms step_avg:43.71ms -step:323400/500000 train_loss:2.0156 train_time:14136416ms step_avg:43.71ms -step:323600/500000 train_loss:2.0626 train_time:14145143ms step_avg:43.71ms -step:323800/500000 train_loss:1.9674 train_time:14153871ms step_avg:43.71ms -step:324000/500000 train_loss:2.0524 train_time:14162599ms step_avg:43.71ms -step:324200/500000 train_loss:2.0582 train_time:14171336ms step_avg:43.71ms -step:324400/500000 train_loss:1.8915 train_time:14180068ms step_avg:43.71ms -step:324600/500000 train_loss:2.0369 train_time:14188927ms step_avg:43.71ms -step:324800/500000 train_loss:2.0923 train_time:14197654ms step_avg:43.71ms -step:325000/500000 train_loss:1.9624 train_time:14206393ms step_avg:43.71ms -step:325200/500000 train_loss:2.0010 train_time:14215134ms step_avg:43.71ms -step:325400/500000 train_loss:2.0460 train_time:14223861ms step_avg:43.71ms -step:325600/500000 train_loss:2.0737 train_time:14232589ms step_avg:43.71ms -step:325800/500000 train_loss:1.8715 train_time:14241323ms step_avg:43.71ms -step:326000/500000 train_loss:1.9732 train_time:14250055ms step_avg:43.71ms -step:326200/500000 train_loss:1.8768 train_time:14258799ms step_avg:43.71ms -step:326400/500000 train_loss:1.9460 train_time:14267540ms step_avg:43.71ms -step:326600/500000 train_loss:1.9567 train_time:14276274ms step_avg:43.71ms -step:326800/500000 train_loss:1.9932 train_time:14285014ms step_avg:43.71ms -step:327000/500000 train_loss:2.0671 train_time:14293752ms step_avg:43.71ms -step:327200/500000 train_loss:1.9528 train_time:14302485ms step_avg:43.71ms -step:327400/500000 train_loss:2.0451 train_time:14311217ms step_avg:43.71ms -step:327600/500000 train_loss:2.0360 train_time:14319952ms step_avg:43.71ms -step:327800/500000 train_loss:2.0273 train_time:14328692ms step_avg:43.71ms -step:328000/500000 train_loss:2.1318 train_time:14337428ms step_avg:43.71ms -step:328200/500000 train_loss:2.0219 train_time:14346167ms step_avg:43.71ms -step:328400/500000 train_loss:1.8702 train_time:14354904ms step_avg:43.71ms -step:328600/500000 train_loss:2.2112 train_time:14363648ms step_avg:43.71ms -step:328800/500000 train_loss:2.0007 train_time:14372507ms step_avg:43.71ms -step:329000/500000 train_loss:1.9333 train_time:14381249ms step_avg:43.71ms -step:329200/500000 train_loss:1.8833 train_time:14389989ms step_avg:43.71ms -step:329400/500000 train_loss:1.9639 train_time:14398731ms step_avg:43.71ms -step:329430/500000 val_loss:1.9837 val_bpb:1.1749 train_time:14400039ms step_avg:43.71ms -stopping_early: wallclock_cap train_time:14400039ms step:329430/500000 -peak memory allocated: 10184 MiB reserved: 10588 MiB -Serialized model: 67224983 bytes -Code size: 47642 bytes -Total submission size: 67272625 bytes -Serialized model int8+zlib: 15762519 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) -Total submission size int8+zlib: 15810161 bytes -final_int8_zlib_roundtrip val_loss:2.0386 val_bpb:1.2074 eval_time:1356ms -final_int8_zlib_roundtrip_exact val_loss:2.03860961 val_bpb:1.20737944 diff --git a/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/train_gpt.py b/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/train_gpt.py deleted file mode 100644 index 0deb0565f5..0000000000 --- a/records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/train_gpt.py +++ /dev/null @@ -1,1126 +0,0 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# ----------------------------- -# DATA LOADING -# ----------------------------- - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: - x = self.tok_emb(input_ids) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") - - -# ----------------------------- -# TRAINING -# ----------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5 - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - if distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") - - if distributed: - dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/README.md b/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/README.md deleted file mode 100644 index d839d5e4f9..0000000000 --- a/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/README.md +++ /dev/null @@ -1,67 +0,0 @@ -# Non-Record Submission: SwiGLU + Warmdown Fix + Quarter Batch (1×RTX 5090) - -This is a non-record submission documenting a systematic 10-experiment exploration on a **single RTX 5090**, iterating from the stock baseline toward better val_bpb under the 16MB artifact constraint. - -The final post-quant score (**1.3281 val_bpb**) does not beat the 8×H100 baseline (1.2244) due to hardware throughput limitations (~3,773 steps vs ~13,780 on 8×H100), but the individual improvements — particularly the **warmdown schedule bug fix** — are hardware-agnostic and should transfer directly to multi-GPU runs. - -## Summary of Changes (cumulative, all kept) - -1. **SwiGLU activation** replacing ReLU² — better gating mechanism, widely adopted in modern LLMs -2. **Warmdown schedule bug fix** — stock config decays LR from step 1; fixed via time-fraction approach -3. **Reduced MLP hidden (640)** — trades params for artifact budget headroom -4. **Quarter batch size (131K tokens)** — 4× more optimizer steps in the same wall-clock time -5. **Gradient accumulation (2 steps)** — doubles effective batch without increasing per-step memory - -## Key Discovery: Warmdown Schedule Bug - -The stock `train_gpt.py` sets `warmdown_iters=1200`, but with a 600s wallclock cap the implied warmdown window exceeds total training time. This means the learning rate decays from step 1 — the model never trains at full LR. - -**Fix:** Replace iteration-based warmdown with a time-fraction approach (`warmdown_frac=0.2`), so warmdown occupies the last 20% of wall-clock time. This alone gave **-0.006 bpb** improvement. - -## Full Experiment Log - -| Exp | Description | val_bpb | Delta | Artifact (MB) | Status | -|-----|-------------|---------|-------|---------------|--------| -| 001 | Baseline (stock config) | 1.3633 | — | 12.3 | keep | -| 002 | SwiGLU MLP | 1.3592 | -0.0041 | 15.1 | keep | -| 003 | Warmdown fix (time-fraction 20%) | 1.3536 | -0.0056 | 17.9 | discard (>16MB) | -| 004 | SwiGLU(768) + warmdown fix | 1.3496 | -0.0096 | 15.4 | keep | -| 005 | Half batch (262K tokens) | 1.3336 | -0.0160 | 16.6 | discard (>16MB) | -| 006 | Half batch + MLP hidden 704 | 1.3359 | -0.0137 | 15.8 | keep | -| 007 | Quarter batch (131K) + MLP hidden 640 | 1.3305 | -0.0054 | 15.3 | keep | -| 008 | + Gradient accumulation ×2 | **1.3281** | -0.0024 | 15.3 | **best** | -| 009 | + Weight decay 0.01 | 1.3284 | +0.0002 | 15.3 | discard | -| 010 | Layer recurrence ×2 | 1.3791 | +0.0510 | 15.1 | discard | - -**Total improvement over baseline: -0.0352 bpb** (1.3633 → 1.3281) - -## Negative Results Worth Noting - -- **Weight decay** (exp009): No benefit at this scale/duration. The regularization effect is negligible for short training runs. -- **Layer recurrence** (exp010): Doubling depth by reusing weights halves the number of training steps in fixed wall-clock time, which more than offsets any capacity gain. Worst result since baseline (+0.051 bpb). - -## Configuration (Best Run — exp008) - -``` -VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 -MLP_HIDDEN=640 TRAIN_BATCH_TOKENS=131072 TRAIN_SEQ_LEN=1024 -WARMDOWN_FRAC=0.2 GRAD_ACCUM_STEPS=2 -``` - -Key metrics: -- `val_bpb` (post-quant): **1.32814313** -- Artifact size: **15,327,112 bytes** (~670KB headroom) -- Model params: 16,470,088 -- Steps completed: 3,773 -- Peak memory: 10,225 MiB -- GPU: 1×RTX 5090, 600s wallclock - -## Hardware Note - -All experiments ran on a single RTX 5090 with a 10-minute wallclock cap. The throughput gap vs 8×H100 (~3.6× fewer steps) explains the score gap vs the baseline leaderboard entry. The architectural and schedule improvements documented here are hardware-agnostic and intended to be validated on 8×H100 as a next step. - -## Included Files - -- `train_gpt.py` — code snapshot of the best configuration so far (008) -- `results.tsv` — full experiment results table -- `submission.json` — leaderboard metadata diff --git a/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/results.tsv b/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/results.tsv deleted file mode 100644 index e24185c862..0000000000 --- a/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/results.tsv +++ /dev/null @@ -1,11 +0,0 @@ -val_bpb artifact_bytes memory_mb notes -1.36330099 12265862 10255 baseline unmodified -1.35919834 15102915 11552 SwiGLU MLP replacing ReLU² -1.35355753 17944610 11552 warmdown fix (artifact over 16MB) -1.34958600 15369302 10702 SwiGLU(768) + warmdown fix -1.33357453 16553858 5523 half batch 262K (artifact over 16MB) -1.33591363 15811633 5354 SwiGLU(704) + half batch 262K + warmdown fix -1.33052627 15326871 2789 SwiGLU(640) + quarter batch 131K + warmdown fix -1.32814313 15327112 10225 grad_accum=2, batch 131K, SwiGLU(640), warmdown -1.32836253 15325763 10221 weight decay 0.01 (no improvement) -1.37914359 15066998 19599 layer recurrence x2 (worse, fewer steps) \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/submission.json b/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/submission.json deleted file mode 100644 index bae4218881..0000000000 --- a/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/submission.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "author": "", - "github_id": "", - "name": "SwiGLU + Warmdown Fix + Quarter Batch (1x5090)", - "blurb": "Non-record 1xRTX5090 submission: systematic 10-experiment exploration discovering warmdown schedule bug, SwiGLU activation, quarter batch sizing, and gradient accumulation. Total improvement -0.035 bpb over stock baseline. Post-quant val_bpb 1.3281 under the 16MB artifact cap.", - "date": "2026-03-19T00:00:00Z", - "track": "non-record-16mb", - "val_loss": null, - "val_bpb": 1.32814313, - "pre_quant_val_loss": null, - "pre_quant_val_bpb": null, - "step_stop": 3773, - "wallclock_seconds": 600, - "bytes_total": 15327112, - "bytes_model_int8_zlib": null, - "bytes_code": null, - "gpu": "1xRTX5090" - } \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/train_gpt.py b/records/track_non_record_16mb/2026-03-19_SwiGLU_WarmdownFix_QuarterBatch_1x5090/train_gpt.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/README.md b/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/README.md deleted file mode 100644 index 267417c554..0000000000 --- a/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/README.md +++ /dev/null @@ -1,460 +0,0 @@ -# Depth Recurrence in Parameter-Constrained Transformers: What Works, What Doesn't, and Why - -**PR #363 | Non-Record Submission (Research Contribution)** -**Author:** Evangeline Kamin ([@evangelinehelsinki](https://github.com/evangelinehelsinki), itsmeaura/Aura on Discord) -**Base:** PR #325 by Aum08Desai (1.1462 bpb) -**Duration:** 4 days, ~35 runs across 8xH100 SXM bare metal, 2xH100, RTX 3070, and A4500 pods -**Final best (looped):** 1.1787 bpb sliding window | **Flat comparison:** 1.1648 bpb | **Gap:** +0.025 bpb - ---- - -## The Short Version - -I spent four days trying to make depth-recurrent transformers competitive in Parameter Golf. They aren't. A flat 11-layer model beats a looped 3x3 model by 0.025 bpb on identical hardware with identical tricks. Three independent researchers (me, Frosty40, and Ciprian-Florin Ifrim) arrived at the same conclusion from different starting points. - -But the failure is informative, and two findings survived: **Noisy QAT** (a training technique that collapses quantization error amplification through recurrence from 0.37 bpb to 0.002 bpb) and **the 3x3 > 2x5 loop configuration** (more unique blocks with fewer repeats beats fewer blocks with more repeats, on every metric). - -This document covers 250+ hours of experiments, 12 negative results with specific numbers, and an honest post-mortem on why the "save parameters through weight sharing, spend them on more capacity" thesis doesn't work under competition constraints. If you're considering depth recurrence for Parameter Golf, read this first. It will save you days. - ---- - -## Table of Contents - -1. [How I Got Here](#how-i-got-here) -2. [The Architecture](#the-architecture) -3. [What Worked](#what-worked) -4. [The Controlled Comparison](#the-controlled-comparison) -5. [Why Recurrence Fails at This Scale](#why-recurrence-fails-at-this-scale) -6. [The Full Experiment Log](#the-full-experiment-log) -7. [Negative Results (All 12)](#negative-results-all-12) -8. [What Might Work With More Compute](#what-might-work-with-more-compute) -9. [Acknowledgments](#acknowledgments) -10. [Reproducing These Results](#reproducing-these-results) - ---- - -## How I Got Here - -On Day 0, I deployed 15 research agents to mine papers from labs in 12 countries (Chinese, Japanese, Korean, Israeli, Indian, and others) looking for approaches nobody else in the competition was trying. Depth recurrence kept coming up: Samsung's TRM, Alibaba's Huginn, Relaxed Recursive Transformers, Mixture-of-Recursions. The appeal was obvious for a size-constrained competition. If you share weights across loop iterations, you get more effective depth per byte of artifact. My first looped model on a 3070 hit 1.5630 bpb with only 6.1M params and a 4.1MB artifact. 64% fewer parameters than the baseline. I remember seeing that artifact size and thinking "this is going to crush everyone." - -It didn't. - -The gap between "this architecture is parameter-efficient" and "this architecture is competitive in a 10-minute training race" turned out to be enormous. But figuring out exactly *why* it's enormous, and documenting every attempt to close it, is (I think) more useful to the community than another 0.001 improvement on the standard 11L stack. - -### Background on me - -I'm a high school student in Phoenix. I work as a waitress. I have no formal ML background. My compute budget for this competition was about $30 out of pocket plus $170 in Hyperbolic referral credits (thank you to whoever started the referral chain in the Discord, and sorry to Hyperbolic's VCs). My development hardware ranged from an RTX 3070 to bare metal 8xH100 SXM5 nodes rented by the hour. I mention this not for sympathy points but for context: every experiment had a real dollar cost, which shaped which experiments I ran and how carefully I designed them. - -### The research pipeline - -To compensate for limited compute, I built an aggressive research pipeline: -- **15 parallel research agents** scanning recent papers, filtering for parameter-efficient training techniques relevant to the 16MB/10min constraint -- **A 26-model code review gauntlet** where I ran my training script through GPT-5, Gemini 3.1 Pro, DeepSeek V3.2, O3 Deep Research, Kimi K2.5, Claude Opus, and 20 others. This caught a critical `global _QAT_ACTIVE` bug (QAT may have never been running), env var name mismatches, torch.compile recompilation stalls, and redundant zero_grad calls. -- **Systematic PR mining**: I fetched and analyzed all 600+ competition PRs, spawning subagents to deep-dive the top submissions. This is how I tracked the converging "meta stack" and identified which techniques were worth testing on my architecture. - ---- - -## The Architecture - -### The Thesis - -Depth recurrence (reusing the same transformer blocks multiple times in a forward pass) has a long lineage: Universal Transformer (Dehghani et al., 2019), Huginn (Alibaba, 2025), Samsung TRM, and several Parameter Golf submissions including PR #325 by Aum08Desai. Share weights across loop iterations, get more effective depth per byte of artifact. In a competition with a 16MB cap, this should be a cheat code. - -### Middle-Cycle Layout - -PR #325 introduced a "Middle-Cycle" architecture that splits layers into three sections: - -``` -[Stem blocks] → [Core blocks × R repeats] → [Tail blocks] -``` - -- **Stem blocks**: Unique layers processing raw embeddings. Not shared. -- **Core blocks**: Shared layers that execute R times. This is where the parameter savings come from. -- **Tail blocks**: Unique layers producing final representations. Not shared. -- **U-Net skip connections**: Stem outputs added (with learnable weights) to tail block inputs. - -I tested two configurations extensively: - -| Config | Stem | Core | Repeats | Tail | Effective Depth | Unique Blocks | -|--------|------|------|---------|------|-----------------|---------------| -| **3x3** | 3 | 3 | 3 | 3 | 12 | 9 | -| **2x5** | 2 | 2 | 5 | 2 | 16 | 6 | - -The 2x5 was my starting point (forked from PR #325). The 3x3 came from studying Frosty40's Frugendorff architecture (PR #499), which used 6 blocks × 2 repeats. More on why 3x3 won later. - -Both configs used 640d model dimension, 8 attention heads with 4 KV heads (GQA), 3x MLP expansion, tied embeddings with vocab 1024, and SmearGate + BigramHash + RoPE from the PR #325 base. - -### Where this sits in the competition - -The meta as of ~640 PRs is flat 11-12 layer architectures at 512d. For reference: - -| PR | Score (bpb) | Approach | -|----|-------------|----------| -| #573 | 1.0523 | Multi-pass streaming legal TTT (overall leader) | -| #609 | 1.1154 | Flat 11L, XSA-all + Full GPTQ, no TTT | -| #593 | 1.1171 | Flat 11L, Parallel Muon + Full GPTQ, no TTT | -| #325 | 1.1462 | Looped 2x5, Middle-Cycle (my starting point) | -| **#363 (this PR)** | **1.1787** | **Looped 3x3, Noisy QAT + EMA + MTP** | - -My best looped result is 0.063 bpb behind the best no-TTT flat submission. That gap is the cost of recurrence under these constraints. - ---- - -## What Worked - -### 1. Noisy QAT (Original Contribution) - -This is the finding I'm most proud of and the reason this PR exists. - -**The discovery**: On Day 1, my first 8xH100 run produced a catastrophic result. Pre-quantization bpb was 2.07 (decent for the architecture). Post-quantization bpb was 3.22. A **1.14 bpb gap**. The model was learning fine but quantization was destroying it. - -Standard STE (Straight-Through Estimator) quantization-aware training simulates quantization during the forward pass. This works for flat architectures where each weight matrix is used once. But for looped architectures, quantization error compounds: the same weights get quantized once at export, but errors propagate through N repeat iterations. I measured the amplification factor at roughly **900x through 3 recurrence cycles**. Int6 starts with about 4x more error than int8, and that compounds through the loop into something catastrophic. - -**The fix**: Instead of STE fake-quantization, inject differentiable uniform noise calibrated to match the magnitude of int8 per-row quantization error: - -```python -# In CastedLinear.forward(), for loop core blocks only: -with torch.no_grad(): - amax = self.weight.float().abs().amax(dim=1, keepdim=True).clamp_min(1e-12) - step_size = amax / 127.0 -noise = (torch.rand_like(w) - 0.5) * step_size.to(w.dtype) -w = w + noise -``` - -Key properties: -- **Differentiable**: Unlike STE, gradients flow through the noise. The model learns weight configurations robust to quantization-scale perturbations. -- **Loop-aware**: Applied only to core (shared) blocks, not stem/tail. -- **Calibrated**: Noise magnitude matches int8 per-row quantization step size. Not arbitrary regularization; matched to the actual export format. - -**Result**: Quantization gap collapsed from **0.37 bpb to 0.002 bpb**. That's a 185x reduction. The technique is simple, costs nothing at inference, and should transfer to any depth-recurrent architecture. - -(An aside: on the Middle-Cycle architecture with int5 export, Noisy QAT calibrated for int8 actually hurts slightly because the noise magnitude is wrong for int5 step sizes. Matching the noise to the actual export precision is critical. See negative result #10.) - -### 2. SWA Inverts the Quantization Gap on Middle-Cycle - -This was the weirdest result. Stochastic Weight Averaging (SWA), which periodically averages model checkpoints during training, produces smoother weight distributions. On the Middle-Cycle architecture, post-quantization bpb was sometimes **better** than pre-quantization bpb. - -My hypothesis: SWA pushes weights toward flatter minima where the weight distribution is more uniform across rows. Per-row quantization handles uniform distributions well. The smoothing effect of SWA accidentally compensates for quantization noise rather than fighting it. - -This might be useful to anyone combining SWA with aggressive quantization schemes. - -### 3. 3x3 > 2x5 Loop Configuration - -This is the most practically useful finding for anyone working on looped transformers. - -I switched from 2x5 to 3x3 after studying Frosty40's Frugendorff (PR #499), which used 6 unique blocks looped only 2x. The intuition: more unique blocks with fewer repeats provides more representational diversity per parameter. - -**Controlled comparison (single GPU, identical hyperparameters):** - -| Config | Effective Depth | bpb | Artifact Size | ms/step | -|--------|----------------|-----|---------------|---------| -| **3x3** (3 core × 3 repeats) | 12 | **1.3462** | **11.9 MB** | **236** | -| 2x5 (2 core × 5 repeats) | 16 | 1.3519 | 13.2 MB | 260 | - -3x3 wins on every axis: **-0.006 bpb, -1.3 MB smaller, -24 ms/step faster**. Two shared blocks repeated 5 times gives the model only 2 distinct computational "programs" to compose. Three shared blocks repeated 3 times gives 3 distinct programs, 50% more diversity, at the cost of only one additional block's worth of parameters. - -### 4. The Training Data Shard Lesson - -This one cost me hours of debugging and I'm including it as a public service announcement. - -Midway through Day 3, I was getting 1.28 bpb on an 8xH100 VM where I'd previously gotten 1.18 on Hyperbolic bare metal. Same code, same config. I ran A/B tests, made LeakyReLU configurable, checked for code regressions. Nothing explained it. - -The root cause: **I had only downloaded 1 training shard instead of 80.** The model was memorizing that single shard and generalizing poorly to the validation set. With 80 shards: 1.1914. With 1 shard: ~1.30. A 0.1 bpb difference from training data diversity alone. - -Always use all 80 shards. Always. - ---- - -## The Controlled Comparison - -This is the definitive experiment. Same hardware (8xH100 SXM bare metal), same quantization (all-int5), same attention config (full MHA, 8 KV heads), same BigramHash (4096), same warmdown (2000), same seed, same eval pipeline (sliding window stride 64, T=0.90). - -| | Flat 11L 512d | Looped 3x3 640d | Delta | -|---|---|---|---| -| **bpb (sliding window)** | **1.1648** | 1.1894 | **+0.025** (looped worse) | -| Artifact size | 15.3 MB | 14.5 MB | -0.8 MB (looped smaller) | -| Training steps | 5375 | 4175 | -1200 steps (looped fewer) | -| ms/step | 112 | 144 | +32 ms (looped slower) | - -The looped model trains for 1200 fewer steps and each step is 32ms slower. In a 600-second time budget, this is devastating. - -Frosty40 shared his own conclusion in the Discord on the same day: *"yeah i did a ton of a/b testing and its not improving anything, it was other modifications. so now im stripping those and running a/b. the recursion in this form is a bust."* He added: *"i kept adding shit to the 'recursive layer' exciting it was getting faster, and those modifications worked anyway, layer was just wasting cycles."* - -Ciprian-Florin Ifrim, who ran 250+ experiments for his ternary submission and documented everything in a PDF I wish I'd had on Day 1, found the same. His eval depth recurrence sweep showed a total range of 0.0009 bpb across 5 different repeat counts. Pure noise. - -Three independent researchers. Three different architectures. Three different optimization approaches. Same conclusion. - ---- - -## Why Recurrence Fails at This Scale - -There are two distinct penalties. I call them the **two taxes of recurrence**. - -### Tax 1: Quantization Compounding - -Shared weights are stored once and quantized once. But during inference, quantization error propagates through every repeat iteration. For 3x3, each core block's error is seen 3 times. For 2x5, 5 times. And the errors compound nonlinearly because each iteration's output feeds into the next iteration's input. - -Noisy QAT partially addresses this (see above), but only for int8 targets. At int5 precision, the interaction between QAT noise and already-aggressive quantization becomes counterproductive. - -boreas in the Discord summarized this perfectly: *"so you can't scale the recurrence to take advantage of the smaller size because of the compounding quant tax?"* - -Exactly. - -### Tax 2: Step Time Overhead - -Each loop iteration adds wall-clock time. On 8xH100: - -- Flat 11L: 600s / 0.112s = **~5375 steps** -- Looped 3x3: 600s / 0.144s = **~4175 steps** - -That's 22% fewer training steps. In a regime where every step matters, this is a brutal penalty. - -### Why the Size Advantage Cannot Compensate - -The looped model is 0.8 MB smaller (14.5 vs 15.3 MB). Could that headroom fund higher precision to close the 0.025 bpb gap? - -No. Moving from int5 to int8 on 0.8 MB of parameters improves roughly 0.005 bpb (based on competition-wide quant deltas). That's an order of magnitude short of the 0.025 gap. The parameter savings from weight sharing are real but insufficient to offset both taxes combined. - ---- - -## The Full Experiment Log - -### Day 0: Research + 3070 Prototyping - -- Deployed 15 research agents across Chinese, Japanese, Korean, Israeli, Indian labs -- Identified depth recurrence as the unexplored lane -- Built first looped model on 3070: 1.5630 bpb, 6.1M params, 4.1MB artifact -- Ran scaling sweep on 3070: tested wide (3x3 at 768d), deep (5x3 at 512d), balanced (4x4 at 640d) -- All larger configs throughput-limited on 3070; couldn't get enough steps to converge -- Investigated custom compression (entropy analysis showed 2.94 bits/value for int6 vs 5.0-5.5 from zstd) -- Tested bit-packing, delta encoding (delta encoding was a dud), Huffman coding concepts - -### Day 1: A4500 Testing, First 8xH100, The Quantization Discovery - -- Rented 2x A4500 pods ($0.19/hr spot) for scaling sweeps -- Tested LoRA adapters on recurrence: NoLoRA won at low step counts -- BigramHash stacked well with recurrence -- SmearGate hurt recurrence (gating mechanism incompatible with shared weights) -- MTP broke badly (auxiliary gradients corrupted shared recurrent weights) -- **First 8xH100 run: catastrophic 1.14 bpb quantization gap** (pre-quant 2.07, post-quant 3.22) -- Discovered the ~900x error amplification through recurrence cycles -- **Developed Noisy QAT**: gap collapsed from 0.37 to 0.002 bpb -- Submitted PR #363 as non-record research contribution - -### Day 2: Forking PR #325, Code Review Gauntlet, Sweeps - -- Forked Node's PR #325 (looped 2x5 Middle-Cycle architecture) -- Applied batch fixes: Muon 0.99, warmdown adjustment, Partial RoPE 16/64, LN Scale, XSA last 4, Late QAT -- Discovered SWA gap inversion (post-quant sometimes better than pre-quant on Middle-Cycle) -- **26-model code review gauntlet** found the `global _QAT_ACTIVE` bug and 5 other issues -- Ran parallel hyperparameter sweeps on two 2xH100 rigs while at work -- Confirmed: EMA(0.997) ≈ SWA, warmdown 1500 > 3000 > 1000, MTP 4 heads / weight 0.3, Muon WD 0.02 -- GPTQ-lite: -0.0027 bpb (free, post-training) -- Value Residual: catastrophically incompatible with loops (+0.14 worse) -- TTT with AdamW: catastrophically overfit at lr=0.0005 (1.5636 bpb) - -### Day 3: 3x3 Beats 2x5, The Shard Lesson, Architecture Switch - -- Tested 3x3 vs 2x5 after studying Frosty40's Frugendorff: **3x3 won on every dimension** -- Lost hours debugging 1.28 bpb on 8xH100 VM; root cause was 1 training shard instead of 80 -- With 80 shards: 1.1914. With 1 shard: ~1.30. -- Best 8xH100 looped result: **1.1787 bpb** sliding window (3x3 + EMA + MTP + int5 + GPTQ-lite + T=0.90) -- Tried FIIZiK_'s techniques: stride 16 eval (-0.015 bpb, huge), T=0.90 (optimal for relu² via grid search) -- Factored embeddings at 192d: catastrophic (+0.053 regression). At 256d: still bad (+0.063) -- FIIZiK_ told me his optimal was 256 on 768d, but it doesn't transfer to our int5 setup - -### Day 4: Flat Comparison, Accepting the Data - -- Frosty40 DMs me: recursion is a bust, he's stripping it out after days of DGX Spark A/B testing -- FIIZiK_ asks if I'm on the recurrent transformer; I tell him yes, factored dims didn't work, 1.1787 -- He says: *"Well 1.18 to 1.17 is nice"* and *"I mean that's not the point of this challenge imo"* -- **Ran the controlled flat vs looped comparison**: flat 1.1600 (int6, over budget), flat 1.1648 (all-int5, fits), looped 1.1894 (same tuned config) -- Flat wins by 0.025. The loop adds ~32ms/step overhead = 1200 fewer training steps. -- Tried adding the loop back to the tuned flat config just to be sure: confirmed +0.025 penalty -- Compared against Frosty40's PR #499: his MLP 4x and 6×2 loop gave 1.1478, better than our 3×3 with 3x MLP, but his own A/B testing showed the gains came from MLP width, not the loop - -### 8xH100 Results Summary - -| Config | Sliding bpb | Steps | ms/step | Artifact | Fits? | -|--------|------------|-------|---------|----------|-------| -| Flat 11L tuned (fullMHA+bg4096+wd2000, all-int5) | **1.1648** | 5375 | 112 | 15.3MB | YES | -| Flat 11L baseline (GQA, bg2048, wd1500, all-int5) | 1.1671 | 5550 | 108 | 15.0MB | YES | -| Flat 11L (int6, over budget) | 1.1600 | 5550 | 108 | 17.2MB | NO | -| Looped 3x3 best (EMA+MTP+int5+GPTQ-lite) | 1.1787 | 4200 | 143 | 15.6MB | YES | -| Looped 3x3 tuned (same config as flat winner) | 1.1894 | 4175 | 144 | 14.5MB | YES | -| Looped 2x5 (original PR #325 fork, 3-seed mean) | 1.1834 | 4200 | 143 | 15.6MB | YES | - -### Hyperparameter Sweeps (2xH100) - -All sweeps on 2xH100 with 1 data shard. Directionally reliable but absolute numbers are higher than 8xH100. - -**EMA x Warmdown** (20 combinations, most corrupted by torch.compile recompilation): -- Best surviving: EMA 0.996, Warmdown 2000 = 1.2910 bpb - -**MTP (Multi-Token Prediction)**: - -| MTP Heads | Loss Weight | bpb | -|-----------|-------------|-----| -| **4** | **0.3** | **1.2974** | -| 6 | 0.3 | 1.3010 | -| 2 | 0.3 | 1.3045 | - -**Muon Weight Decay** (lower is better for looped, opposite to flat convention): - -| WD | bpb | Delta | -|----|-----|-------| -| **0.02** | **1.2955** | baseline | -| 0.04 | 1.2983 | +0.003 | -| 0.06 | 1.3060 | +0.011 | - -Hypothesis: weight decay on shared parameters has an outsized effect because those weights are used in every loop iteration. Aggressive decay compounds through the loop just like quantization error. - ---- - -## Negative Results (All 12) - -Every failed experiment, with specific numbers. This section may be the most useful part of this writeup. - -### 1. XSA on All Layers (Looped) - -XSA applied to all blocks including loop core on every repeat: **+0.001 worse** (1.1953 vs 1.1940). On a looped architecture, "all layers" means the shared core blocks get XSA on every repeat. Too aggressive. The standard 11L stack benefits because its "all 11 layers" means 11 *unique* computations. Our "all layers" means 3 unique computations, each repeated 3 times. Very different. - -### 2. Cyclic Muon Momentum (0.85-0.95, period 50) - -Reported as -0.0045 bpb on flat architectures (PR #623). Combined with XSA and QuadgramHash: **+0.058 worse** (catastrophic). The momentum drops below the warmup target (0.85), destabilizing looped convergence. Looped architectures amplify optimizer instability because perturbations compound through repeat iterations. - -### 3. QuadgramHash (1024 buckets, dim 32) - -Tested alongside cyclic momentum and XSA. Could not isolate. When the combined test came back +0.058 worse, there wasn't compute budget to test each independently. Inconclusive. - -### 4. Factored Embeddings (EMBED_DIM 192 and 256) - -FIIZiK_ used EMBED_DIM=254 on his 768d ternary model and called it "very small loss." But his architecture is fundamentally different (ternary weights, 8192 vocab). On our int5 setup with vocab 1024: - -| EMBED_DIM | Ratio | bpb | Delta | Artifact | -|-----------|-------|-----|-------|----------| -| 640 (none) | 100% | 1.1787 | baseline | 15.6MB | -| 256 | 40% | 1.2416 | **+0.063** | 14.8MB | -| 192 | 30% | 1.2316 | **+0.053** | 16.4MB (OVER) | - -Both terrible. With a 1024-token vocabulary, the embedding table is already small (1024 × 512 = 0.5M params). Compressing it further saves negligible parameters while destroying representation quality. Factored embeddings only make sense with large vocabularies (FIIZiK_ uses 8192). - -### 5. Value Residual (ResFormer) - -Reported as -0.015 bpb on flat architectures (PRs #486/#490). On looped: **+0.14 worse** (1.4378 bpb). Catastrophic. Even with initialization fix (lambda init at -4.0, so sigmoid(-4.0) ≈ 0.018 = almost no mixing initially). - -In a looped architecture, the "first layer V" is from the stem, but the loop core sees it on every iteration. The V residual creates an increasingly stale reference as depth increases, and the shared weights cannot learn different mixing ratios for different repeat iterations. Value Residual assumes each layer has a unique position in the network; shared layers violate that assumption. - -### 6. Progressive Loop Unrolling (2 → 5 repeats) - -Start training with 2 loop repeats, linearly increase to 5. Broke DDP. Dynamic control flow is incompatible with torch.compile + DistributedDataParallel. Single-GPU test: **2172 ms/step** (9x slower than baseline 236 ms/step). The compile graph breaks on every repeat-count change, triggering full recompilation. - -### 7. Sawtooth LR Schedule - -Caused torch.compile recompilation **every step** because the LR change triggers a guard check. Step time went from 248 ms to **987 ms** (4x slowdown). Only 607 steps completed. Results were garbage. - -Same root cause as #6: anything that changes a value torch.compile traces through causes recompilation. LR schedules must be implemented outside the compiled region. - -### 8. Test-Time Training (Full-Weight) - -829 steps of AdamW on validation data: **1.56 bpb** vs 1.38 baseline. Massive overfitting. GPTQ-quantized weights sit in narrow curvature-aligned minima that AdamW's adaptive learning rates destroy. TTT and aggressive quantization are fundamentally at odds unless using SGD or carefully constrained LoRA. - -(Per-document LoRA TTT was implemented but DDP crashes prevented proper multi-GPU testing. Still on the to-do list.) - -### 9. LeakyReLU(0.5)² - -Reported as -0.003 on flat architectures. Showed **-0.003 improvement on 2xH100** (1-shard) but **negligible on 8xH100** (80-shard). The benefit may be data-regime-dependent: with 1 shard the model sees less diversity, and leaky activation's gradient flow through negative values helps; with 80 shards the model learns to route around dead ReLU regions naturally. - -**Always validate single-GPU findings on the target hardware.** - -### 10. Late QAT + int5 - -Enable QAT in the final 10% of steps, combined with int5 export: **+0.006 worse**. QAT calibrated for int8 noise is the wrong magnitude for int5 export. The model gets trained to be robust to int8-scale perturbations but actually faces int5-scale perturbations at export. Matching QAT noise to export precision is critical. - -### 11. BigramHash(10240) - -Reported as -0.070 bpb on flat 11L (PR #450). On looped: **no improvement** (1.2980 vs 1.2963 on 2xH100). Hypothesis: the looped architecture already gets some n-gram-like pattern recognition from seeing data multiple times through the loop. The additional bigram capacity is redundant with what the loop provides. - -### 12. 704d Model Dimension - -Increase from 640d to 704d for more capacity per block: **worse** on 2xH100. Fewer steps at higher ms/step. The wider model doesn't train enough in 10 minutes to compensate for increased per-step cost. - ---- - -## What Might Work With More Compute - -Honest speculation, clearly labeled. - -### Longer Training Budgets - -The fundamental issue is that looped models trade step count for effective depth. In 10 minutes, this trade is unfavorable. At 30+ minutes (or unlimited track), the step-count penalty shrinks while the parameter-efficiency advantage grows. PR #612 achieves 1.1079 bpb on the unlimited (100-min) track with a GEPA architecture. Looped architectures may be competitive at longer time horizons where the "Tax 2" (step time overhead) becomes less dominant. - -### Adaptive Depth at Inference - -If the model could choose how many loop iterations per token, easy tokens could exit early and hard tokens could iterate longer. This is the Universal Transformer's original proposal. The challenge: making this compatible with torch.compile and batched inference, both of which demand static computation graphs. - -### Noisy QAT Matched to Export Precision - -Our Noisy QAT was calibrated for int8 (step_size = amax / 127.0) but we exported at int5. A version calibrated for int5 noise (step_size = amax / 15.0) might close the gap. We ran out of compute to test this. - -### Better Loop Designs - -The 3x3 > 2x5 finding suggests the optimal configuration isn't obvious. Asymmetric loops (more stem than tail), heterogeneous repeat counts (repeat block 1 more than block 2), or attention on first and last repeat only with MLP-only middle repeats are all unexplored. - ---- - -## Acknowledgments - -- **Aum08Desai** (PR #325): The Middle-Cycle architecture and original 1.1462 bpb looped submission. -- **Frosty40** (PR #499, "The Frugendorff"): For sharing his negative results on recursion openly, both in DMs and in the public Discord. His honest assessment ("the recursion in this form is a bust... I kept adding [] to the 'recursive layer' exciting it was getting faster, and those modifications worked anyway, layer was just wasting cycles") saved me and others significant compute. -- **[Ciprian-Florin Ifrim](https://github.com/CiprianFlorin)** (PRs #640/#641): The most thorough experiment documentation in the competition (250+ experiments). His suggestions on eval stride 16, temperature scaling (T=0.90 for relu² — note this is activation-dependent, found via grid search, not a universal default; SwiGLU architectures use T=1.0 since the tail is sharper), factored embeddings, and z-loss directly shaped my experiments. His 250-experiment PDF is a masterclass in systematic ML research. -- **boreas**: For summarizing the core tension better than I could ("so you can't scale the recurrence to take advantage of the smaller size because of the compounding quant tax?"). Exactly. -- **Node / capitlism** (PR #325): For open-sourcing the looped transformer that started this whole investigation and telling people to "feel free to optimize." -- **The flat no-TTT SOTA authors** (PRs #609, #593, #606): The reference points that define what the standard stack can achieve, and indirectly, the ceiling that recurrence has to beat to be worth using. -- **OpenAI / Will DePue**: For sponsoring compute credits, actively answering questions in Discord, and creating a competition that explicitly rewards honest research alongside leaderboard performance. Will's comment that "people aren't being nearly ambitious enough" is what pushed me to continue working on the looped architecture in the first place. -- **Hyperbolic**: For the referral credits that made this possible. Sorry to your VCs. -- **The entire Parameter Golf community** (~640 PRs of shared knowledge): This competition's culture of open experimentation made this work possible. Seeing fbe_dev share his results in real-time, watching the referral credit meta-game unfold, and getting direct coaching from top competitors is not something I expected from an ML competition. - ---- - -## Reproducing These Results - -Training script: `pr325_train_gpt.py` - -Key environment variables for the controlled comparison: - -```bash -# Flat 11L 512d (best submittable: 1.1648 bpb) -NUM_LAYERS=11 MODEL_DIM=512 LOOP_CORE_LAYERS=0 LOOP_REPEATS=1 \ -MLP_INT5=1 ATTN_INT5=1 NUM_HEADS=8 NUM_KV_HEADS=8 \ -BIGRAM_VOCAB_SIZE=4096 WARMDOWN_ITERS=2000 \ -EVAL_TEMPERATURE=0.90 EVAL_STRIDE=64 SEED=42 - -# Looped 3x3 640d (1.1894 bpb on same config) -NUM_LAYERS=9 MODEL_DIM=640 LOOP_CORE_LAYERS=3 LOOP_REPEATS=3 \ -MLP_INT5=1 ATTN_INT5=1 NUM_HEADS=8 NUM_KV_HEADS=8 \ -BIGRAM_VOCAB_SIZE=4096 WARMDOWN_ITERS=2000 \ -EVAL_TEMPERATURE=0.90 EVAL_STRIDE=64 SEED=42 -``` - -Both use `MAX_WALLCLOCK_SECONDS=600` on 8xH100 SXM with 80 training shards. - ---- - -## Final Thoughts - -I set out to prove that depth recurrence could be competitive in Parameter Golf. I failed. But I think the failure is worth more than another 0.001 improvement on the standard stack. - -The two taxes, quantization compounding and step-time overhead, are structural. They are not hyperparameter problems or implementation bugs. They are consequences of the competition's constraints: a fixed time budget that penalizes slower steps, and an artifact size limit that forces aggressive quantization where shared weights compound errors. - -Noisy QAT is, to my knowledge, a novel contribution. The idea that loop-core weights should be trained with noise calibrated to quantization error is simple, effective for int8 targets, and should transfer to any depth-recurrent architecture. The 0.37 → 0.002 bpb gap collapse is the strongest single result in this work. - -The 3x3 > 2x5 finding is immediately actionable: prefer more unique blocks with fewer repeats. - -Everything else is a negative result. I believe documenting these honestly is more valuable than cherry-picking the one configuration where looped models look competitive. When boreas asked "what sort of things did you try?" in the Discord, and Frosty40 warned "DO NOT FRUGENDORFF it just wastes cycles," I realized that the most useful thing I could do was write all of this down so the next person doesn't have to spend 4 days and $200 learning the same lessons. - -If someone finds a way to make recurrence work under these constraints, these failures will save them time. If the gap turns out to be fundamental at this scale, this document explains why. - ---- - -*Best looped: 1.1787 bpb (3x3, 8xH100, sliding window) | Best flat: 1.1648 bpb (11L, same hardware) | Controlled gap: +0.025 bpb (looped worse)* diff --git a/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/requirements.txt b/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/requirements.txt deleted file mode 100644 index 472714a640..0000000000 --- a/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -numpy -tqdm -torch -huggingface-hub -kernels -setuptools -typing-extensions==4.15.0 -datasets -tiktoken -sentencepiece -zstandard diff --git a/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/submission.json b/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/submission.json deleted file mode 100644 index 4c494fa52c..0000000000 --- a/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/submission.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "track": "non_record_16mb", - "date": "2026-03-21", - "name": "Depth Recurrence + Mixed-Precision Quantization", - "author": "Evangeline Kamin", - "github_id": "evangelinehelsinki", - "val_bpb": 2.3876, - "val_loss": 4.0314, - "bytes_total": 1461542 -} diff --git a/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/train_gpt.py b/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/train_gpt.py deleted file mode 100644 index bd73d1d471..0000000000 --- a/records/track_non_record_16mb/2026-03-21_DepthRecurrence_MixedPrecisionQuant/train_gpt.py +++ /dev/null @@ -1,1274 +0,0 @@ -""" -Parameter Golf - Competition Submission - -Single self-contained training script: Recurrent GPT with BigramHash + XSA. - -Architecture: - - NUM_UNIQUE_BLOCKS unique transformer blocks cycled to EFFECTIVE_DEPTH virtual layers - - Per-depth LoRA adapters (DEPTH_LORA_RANK, 0=off) - - Encoder/decoder split with U-Net skip connections - - BigramHash: hash consecutive token pairs into embedding table - - XSA (Exclusive Self Attention) on last N layers - - Late STE QAT (togglable via QAT_FRACTION) - -Run: torchrun --standalone --nproc_per_node=1 train_submission.py -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib -from pathlib import Path - -try: - import zstandard as zstd - _HAS_ZSTD = True -except ImportError: - _HAS_ZSTD = False - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP - -# --------------------------------------------------------------------------- -# Global QAT flag -- toggled mid-training by wallclock fraction -# --------------------------------------------------------------------------- -_QAT_ACTIVE = False - -# --------------------------------------------------------------------------- -# Hyperparameters -# --------------------------------------------------------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - p for p in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") if p -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - p for p in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") if p -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - - -class Hyperparameters: - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - # Recurrence - num_unique_blocks = int(os.environ.get("NUM_UNIQUE_BLOCKS", 3)) - effective_depth = int(os.environ.get("EFFECTIVE_DEPTH", 9)) - depth_lora_rank = int(os.environ.get("DEPTH_LORA_RANK", 4)) - - # BigramHash - bigram_buckets = int(os.environ.get("BIGRAM_BUCKETS", 2048)) - bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) - - # XSA - xsa_last_n = int(os.environ.get("XSA_LAST_N", 3)) - - # QAT - qat_fraction = float(os.environ.get("QAT_FRACTION", 0.0)) - - # Muon weight decay - muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) - - -# --------------------------------------------------------------------------- -# Muon Optimizer -# --------------------------------------------------------------------------- - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - wd = group["weight_decay"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - if wd > 0: - g = g.add(p.data, alpha=wd) - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss - - -# --------------------------------------------------------------------------- -# Tokenizer / Evaluation Helpers -# --------------------------------------------------------------------------- - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_data_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, model: nn.Module, rank: int, world_size: int, - device: torch.device, grad_accum_steps: int, val_tokens: Tensor, - base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - f"VAL_BATCH_SIZE too small: {args.val_batch_size} for world={world_size} " - f"accum={grad_accum_steps} seq={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - - -# --------------------------------------------------------------------------- -# Quantization (int8 + zlib) -# --------------------------------------------------------------------------- - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(p in name for p in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", - "baseline_tensor_bytes", "int8_payload_bytes"), 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, "scales": scales, "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - out[name] = (q.float() * float(s.item())).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out - - -# --------------------------------------------------------------------------- -# Int6 Quantization + zstd Compression -# --------------------------------------------------------------------------- - -INT6_MAX_VAL = 31 -INT6_CLIP_Q = 99.99984 / 100.0 - -def quantize_float_tensor_int6(t: Tensor) -> tuple[Tensor, Tensor]: - """Quantize a float tensor to 6-bit signed integers [-31, 31] with per-row scaling.""" - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT6_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.clamp(t32, -clip_abs[:, None], clip_abs[:, None]) - scale = (clip_abs / INT6_MAX_VAL).clamp_min(1.0 / INT6_MAX_VAL) - q = torch.clamp(torch.round(clipped / scale[:, None]), -INT6_MAX_VAL, INT6_MAX_VAL).to(torch.int8).contiguous() - return q, scale.to(dtype=torch.float16).contiguous() - clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / INT6_MAX_VAL if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -INT6_MAX_VAL, INT6_MAX_VAL).to(torch.int8).contiguous() - return q, scale - - -def quantize_state_dict_int6(state_dict: dict[str, Tensor]): - """Mixed-precision quantization: int8 for shared block weights (recurrence-sensitive), - int6 for single-use tensors (embeddings, bigram, etc.), fp16 for small/control tensors. - - Recurrent models amplify quantization error ~900x per cycle. Using int8 for shared - blocks reduces this amplified error by 4x vs int6, at minimal artifact cost. - """ - # Shared block weights get int8 (reused in recurrence, error amplifies) - # Everything else gets int6 (used once, can tolerate more noise) - BLOCK_PATTERNS = ("blocks.",) - - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", - "baseline_tensor_bytes", "payload_bytes"), 0, - ) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["payload_bytes"] += tensor_nbytes(t) - continue - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["payload_bytes"] += tensor_nbytes(kept) - continue - stats["num_float_tensors"] += 1 - # Use int8 for all tensors in recurrent models — int6 error amplifies - # through weight-sharing cycles. We have artifact headroom to spare. - q, s = quantize_float_tensor(t) # int8 for everything - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - obj: dict[str, object] = { - "__quant_format__": "mixed_int8_int6_v1", - "quantized": quantized, "scales": scales, "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - - -def dequantize_state_dict_int6(obj: dict[str, object]) -> dict[str, Tensor]: - """Dequantize int6 state dict back to float tensors.""" - # Same logic as int8 — scale multiplication works identically - return dequantize_state_dict_int8(obj) - - -def compress_bytes(data: bytes) -> bytes: - """Compress using zstd-22 if available, otherwise zlib-9.""" - if _HAS_ZSTD: - cctx = zstd.ZstdCompressor(level=22) - return b"ZSTD" + cctx.compress(data) - return b"ZLIB" + zlib.compress(data, level=9) - - -def decompress_bytes(data: bytes) -> bytes: - """Decompress, auto-detecting zstd vs zlib from header.""" - if data[:4] == b"ZSTD": - if not _HAS_ZSTD: - raise RuntimeError("zstandard package required to decompress ZSTD data") - dctx = zstd.ZstdDecompressor() - return dctx.decompress(data[4:]) - if data[:4] == b"ZLIB": - return zlib.decompress(data[4:]) - # Legacy: no header, assume zlib - return zlib.decompress(data) - - -# --------------------------------------------------------------------------- -# Data Loading -# --------------------------------------------------------------------------- - -class TokenStream: - def __init__(self, pattern: str): - self.files = [Path(p) for p in sorted(glob.glob(pattern))] - if not self.files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - self.file_idx = 0 - self.tokens = load_data_shard(self.files[0]) - self.pos = 0 - - def _advance_file(self) -> None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size - self.device = device - self.stream = TokenStream(pattern) - - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) - - -# --------------------------------------------------------------------------- -# STE Fake-Quantize for QAT -# --------------------------------------------------------------------------- - -def _fake_quantize_int6_ste(w: Tensor) -> Tensor: - """Straight-through estimator fake int6 quantize for 2D weight matrices.""" - INT6_MAX = 31.0 - with torch.no_grad(): - amax = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-12) - scale = amax / INT6_MAX - q = torch.clamp(torch.round(w / scale), -INT6_MAX, INT6_MAX) - # STE: forward uses quantized, backward passes through - return w + (q * scale - w).detach() - - -# --------------------------------------------------------------------------- -# Transformer Modules -# --------------------------------------------------------------------------- - -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - - -class CastedLinear(nn.Linear): - def forward(self, x: Tensor) -> Tensor: - w = self.weight - if self.training and _QAT_ACTIVE and w.ndim == 2: - w = _fake_quantize_int6_ste(w) - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, w.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, - rope_base: float, qk_gain_init: float): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim, bias=False) - self.c_k = CastedLinear(dim, kv_dim, bias=False) - self.c_v = CastedLinear(dim, kv_dim, bias=False) - self.proj = CastedLinear(dim, dim, bias=False) - self.proj._zero_init = True - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - self.use_xsa = False - - def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: - bsz, seqlen, dim = x.shape - q = self.c_q(x) + (q_delta if q_delta is not None else 0) - k = self.c_k(x) - v = self.c_v(x) + (v_delta if v_delta is not None else 0) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - q = F.rms_norm(q, (q.size(-1),)) - k = F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - # y shape: (bsz, num_heads, seqlen, head_dim) - # XSA: project out the value-aligned component - if self.use_xsa: - y = y.transpose(1, 2) # (bsz, seqlen, num_heads, head_dim) - v_for_xsa = v.transpose(1, 2) # (bsz, seqlen, num_kv_heads, head_dim) - y = self._xsa_efficient(y, v_for_xsa) - y = y.reshape(bsz, seqlen, dim) - else: - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) - - def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: - """Exclusive Self Attention: remove value-aligned component from attention output.""" - B, T, H, D = y.shape - Hkv = v.size(-2) - group = H // Hkv - y_g = y.reshape(B, T, Hkv, group, D) - vn = F.normalize(v, dim=-1).unsqueeze(-2) # (B, T, Hkv, 1, D) - proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn - return (y_g - proj).reshape(B, T, H, D) - - -class MLP(nn.Module): - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = mlp_mult * dim - self.fc = CastedLinear(dim, hidden, bias=False) - self.proj = CastedLinear(hidden, dim, bias=False) - self.proj._zero_init = True - - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, - mlp_mult: int, rope_base: float, qk_gain_init: float): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -# --------------------------------------------------------------------------- -# Per-Depth LoRA Adapter -# --------------------------------------------------------------------------- - -class DepthLoRA(nn.Module): - def __init__(self, dim: int, q_dim: int, v_dim: int, rank: int): - super().__init__() - if rank <= 0: - self.enabled = False - return - self.enabled = True - self.q_down = nn.Linear(dim, rank, bias=False) - self.q_up = nn.Linear(rank, q_dim, bias=False) - self.v_down = nn.Linear(dim, rank, bias=False) - self.v_up = nn.Linear(rank, v_dim, bias=False) - nn.init.zeros_(self.q_up.weight) - nn.init.zeros_(self.v_up.weight) - - def q_delta(self, x: Tensor) -> Tensor: - if not self.enabled: - return torch.zeros_like(x) - return self.q_up(self.q_down(x)) - - def v_delta(self, x: Tensor) -> Tensor: - if not self.enabled: - return torch.zeros_like(x) - return self.v_up(self.v_down(x)) - - -# --------------------------------------------------------------------------- -# Recurrent GPT with BigramHash + XSA -# --------------------------------------------------------------------------- - -class RecurrentGPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_unique_blocks: int, - effective_depth: int, - depth_lora_rank: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - bigram_buckets: int = 2048, - bigram_dim: int = 128, - xsa_last_n: int = 3, - ): - super().__init__() - if effective_depth < num_unique_blocks: - raise ValueError( - f"effective_depth ({effective_depth}) must be >= num_unique_blocks ({num_unique_blocks})" - ) - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.num_unique_blocks = num_unique_blocks - self.effective_depth = effective_depth - self.model_dim = model_dim - self.tok_emb = nn.Embedding(vocab_size, model_dim) - - # Encoder/decoder split - self.num_encoder_layers = effective_depth // 2 - self.num_decoder_layers = effective_depth - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter( - torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32) - ) - - # Shared transformer blocks - self.blocks = nn.ModuleList([ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(num_unique_blocks) - ]) - - # Per-depth LoRA adapters - kv_dim = num_kv_heads * (model_dim // num_heads) - self.depth_adapters = nn.ModuleList([ - DepthLoRA(model_dim, q_dim=model_dim, v_dim=kv_dim, rank=depth_lora_rank) - for _ in range(effective_depth) - ]) - - self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) - if self.lm_head is not None: - self.lm_head._zero_init = True - - # Cycle pattern - self.block_schedule = [i % num_unique_blocks for i in range(effective_depth)] - - # BigramHash - self.bigram_emb = nn.Embedding(bigram_buckets, bigram_dim) - self.bigram_proj = CastedLinear(bigram_dim, model_dim, bias=False) - self.bigram_buckets_val = bigram_buckets - nn.init.normal_(self.bigram_emb.weight, mean=0.0, std=0.02) - nn.init.zeros_(self.bigram_proj.weight) - - # XSA: enable on last N virtual layers - self.xsa_last_n = xsa_last_n - if xsa_last_n > 0: - for i in range(max(0, effective_depth - xsa_last_n), effective_depth): - block_idx = self.block_schedule[i] - self.blocks[block_idx].attn.use_xsa = True - - self._init_weights() - - def _init_weights(self) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: - x = self.tok_emb(input_ids) - - # BigramHash: add bigram embeddings - prev_tokens = torch.cat([input_ids[:, :1], input_ids[:, :-1]], dim=1) - bigram_ids = (prev_tokens.long() * 1000003 + input_ids.long()) % self.bigram_buckets_val - bigram_out = self.bigram_proj(self.bigram_emb(bigram_ids)) - x = x + bigram_out - - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - skips: list[Tensor] = [] - - # Encoder - for i in range(self.num_encoder_layers): - block_idx = self.block_schedule[i] - adapter = self.depth_adapters[i] - qd_fn = adapter.q_delta if adapter.enabled else None - vd_fn = adapter.v_delta if adapter.enabled else None - x = self.blocks[block_idx](x, x0, qd_fn, vd_fn) - skips.append(x) - - # Decoder - for i in range(self.num_decoder_layers): - vi = self.num_encoder_layers + i - block_idx = self.block_schedule[vi] - adapter = self.depth_adapters[vi] - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - qd_fn = adapter.q_delta if adapter.enabled else None - vd_fn = adapter.v_delta if adapter.enabled else None - x = self.blocks[block_idx](x, x0, qd_fn, vd_fn) - - x = self.final_norm(x) - if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) - else: - logits = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - return F.cross_entropy( - logits.float().reshape(-1, logits.size(-1)), - target_ids.reshape(-1), - reduction="mean", - ) - - -# --------------------------------------------------------------------------- -# Training -# --------------------------------------------------------------------------- - -def main() -> None: - global zeropower_via_newtonschulz5, _QAT_ACTIVE - - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - - # Distributed + CUDA setup - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) - - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"=== SUBMISSION: Recurrent GPT + BigramHash + XSA ===") - log0(f"num_unique_blocks:{args.num_unique_blocks} effective_depth:{args.effective_depth} " - f"depth_lora_rank:{args.depth_lora_rank}") - log0(f"bigram_buckets:{args.bigram_buckets} bigram_dim:{args.bigram_dim}") - log0(f"xsa_last_n:{args.xsa_last_n} qat_fraction:{args.qat_fraction}") - log0(f"muon_weight_decay:{args.muon_weight_decay}") - - # Seeding - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - # Tokenizer + validation - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - - # Model - base_model = RecurrentGPT( - vocab_size=args.vocab_size, - num_unique_blocks=args.num_unique_blocks, - effective_depth=args.effective_depth, - depth_lora_rank=args.depth_lora_rank, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - bigram_buckets=args.bigram_buckets, - bigram_dim=args.bigram_dim, - xsa_last_n=args.xsa_last_n, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - if isinstance(module, Rotary): - module.inv_freq.data = module.inv_freq.data.float() - restore_low_dim_params_to_fp32(base_model) - - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer setup - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p for name, p in block_named_params - if p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p for name, p in block_named_params - if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - - # Depth adapter params - for name, p in base_model.depth_adapters.named_parameters(): - if p.ndim == 2: - matrix_params.append(p) - else: - scalar_params.append(p) - - # BigramHash: embedding to token optimizer, projection to Muon - bigram_embed_params = [base_model.bigram_emb.weight] - matrix_params.append(base_model.bigram_proj.weight) - - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - tok_param_list = [base_model.tok_emb.weight] + bigram_embed_params - optimizer_tok = torch.optim.Adam( - [{"params": tok_param_list, "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, - ) - optimizer_muon = Muon( - matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, weight_decay=args.muon_weight_decay, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, - ) - optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, - ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - n_unique_block_params = sum(p.numel() for p in base_model.blocks.parameters()) - n_adapter_params = sum(p.numel() for p in base_model.depth_adapters.parameters()) - log0(f"total_params:{n_params} unique_block_params:{n_unique_block_params} " - f"adapter_params:{n_adapter_params}") - log0(f"block_schedule:{base_model.block_schedule}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len}") - - # Data loader - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all(): - for opt in optimizers: - opt.zero_grad(set_to_none=True) - - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - - def lr_mul(step, elapsed_ms): - if args.warmdown_iters <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - if warmdown_start <= step < args.iterations: - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) - return 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - if remaining_ms <= warmdown_ms: - return remaining_ms / max(warmdown_ms, 1e-9) - return 1.0 - - # Warmup - if args.warmup_steps > 0: - initial_model_state = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if warmup_step + 1 == args.warmup_steps or (warmup_step + 1) % 10 == 0: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # Training loop - training_time_ms = 0.0 - stop_after_step = None - qat_activated = False - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < args.iterations: - log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - # Late QAT activation - if args.qat_fraction > 0 and not qat_activated and max_wallclock_ms is not None: - qat_start_ms = max_wallclock_ms * (1.0 - args.qat_fraction) - if elapsed_ms >= qat_start_ms: - _QAT_ACTIVE = True - qat_activated = True - log0(f"QAT activated at step:{step} elapsed:{elapsed_ms:.0f}ms") - - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log = args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0) - if should_log: - log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms") - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - # Deactivate QAT for serialization - _QAT_ACTIVE = False - - log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") - - # Serialize: int6 + zstd (fallback to int8 + zlib if zstd unavailable) - artifact_path = "final_model_submission.ptz" - if master_process: - state = base_model.state_dict() - quant_obj, quant_stats = quantize_state_dict_int6(state) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_blob = compress_bytes(quant_buf.getvalue()) - with open(artifact_path, "wb") as f: - f.write(quant_blob) - code_bytes = len(code.encode("utf-8")) - compressor = "zstd-22" if _HAS_ZSTD else "zlib-9" - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["payload_bytes"], 1) - total_artifact = len(quant_blob) + code_bytes - log0(f"artifact: {len(quant_blob)} bytes + {code_bytes} code = {total_artifact} total " - f"(compressor:{compressor} quant:int6 payload_ratio:{ratio:.2f}x)") - if total_artifact > 16_000_000: - log0(f"WARNING: artifact {total_artifact} exceeds 16,000,000 byte cap by {total_artifact - 16_000_000} bytes!") - else: - log0(f"artifact headroom: {16_000_000 - total_artifact} bytes ({(16_000_000 - total_artifact)/1e6:.3f}MB)") - - # Roundtrip validation - if distributed: - dist.barrier() - with open(artifact_path, "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(decompress_bytes(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int6(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms") - log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/README.md b/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/README.md deleted file mode 100644 index e121c6c254..0000000000 --- a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/README.md +++ /dev/null @@ -1,162 +0,0 @@ -# Notable Non-Record Submission: 1.1239 BPB — 106.2 Asymmetric Binary U-Net Transformer - -**1-bit Quantisation + 15L (7 Encoder - 8 Decoder) + NeoMuon + 4x relu² MLP + SmearGate + Factored Tied Embedding + Poly5 Softcap + YaRN 2048 + 8192 BPE + FP8 QAT + LZMA + Stride-16 Sliding Eval** - -**val_bpb: 1.1239** (sliding, seed=42) | **15.67 MB** artifact | 8×H100 SXM, 50k steps (~2.15h) - -> **This is a **non-record submission** — training exceeds the 10-minute wallclock constraint (50,000 steps / ~2.15 hours). Submitted to demonstrate the compression frontier: 106.2 parameters in 15.67MB via 1-bit quantisation. Over 120M possible with FP4 (implemented) with a worse bpb. Full experiment log: [RESULTS.md](RESULTS.md). Complete training logs: [logs/](https://github.com/CiprianFlorin-Ifrim/openai-parameter-golf-submission/tree/main/logs/cuda).** - -## Results (seed=42, 8×H100 SXM) - -| Metric | Value | -|--------|-------| -| Sliding BPB (s16) | **1.1239** | -| val_bpb | 1.1497 | -| RT bpb | 1.1516 | -| Steps | 50,000 | -| ms/step | 155.3 | -| Training time | 7,763s (~2.15h) | -| optimal_T | 0.90 | -| Artifact | 15,670,651 bytes (15.67MB) | -| Parameters | 106,154,616 | - -### Comparison to Ternary Submission - -Binary reaches better absolute quality but requires circa 13x more training time. Within the 10-minute budget, binary's best fitting run (14L, 4,820 steps) scores 1.1824 sliding — 0.025 bpb worse than ternary (my previous record PR). The zero state is worth more at convergence than the 60% parameter density advantage. - -The results document linked here and in my repo showcases all methods and sweeps applied to both Binary and Ternary Bitnets, which unfortunately are incompatible with many methods, such as Tversky Layers, EMA, Muon WD, LM Logit Head ranking and many more. - -## Architecture - -- 15 transformer layers, dim=768, 8 heads, 4 KV heads (GQA), head_dim=96 -- Binary quantisation: weights {-1, +1}, 1 bit/param, per-group (128) absmean scaling -- 4x MLP expansion (hidden=3072) with **relu²** activation, fused gate+up projection -- U-Net encoder/decoder with learned skip weights (ones-init) and per-block residual mix from input embedding -- **SmearGate:** causal cumulative mean blending with learned tanh gate, zero-init for safe residual start -- Factored tied embedding: 8192×254 bottleneck with learned projections -- Polynomial softcap (degree 5, cap=10) with Z-loss regularisation (1e-4) -- YaRN positional encoding (max_len=2048, ROPE_BASE=5000) -- Fused QKV projection -- FlashAttention-3 (Hopper native kernels) -- 106.2M parameters, 15.67MB artifact (97.3M binary + 2.5M fp8 + 70KB code) - -## Key Techniques - -### Architecture -- **Binary quantisation:** 1 bit/param packs 60% more parameters per MB than ternary (1.6 bits/param), allowing 15 layers vs 10 within similar budget -- **4x relu² MLP:* relu² strictly dominates relu; 4x width outperforms 3x even with fewer layers at matched budget -- **SmearGate:** blends each position with causal cumulative mean; adds 22ms/step overhead but provides -0.007 bpb at scale. Viable here because the run is not wallclock-constrained - -### Training -- **NeoMuon** with 3 Newton-Schulz steps optimizer -- **50,000 steps unconstrained:** binary converges slower than ternary (my other #640, at 4,000 steps (the 10-minute equivalent) binary lags by 0.025 bpb. Extended training closes the gap and surpasses ternary, showcasing with "unlimited compute" the models can be quite powerful. -- **524k batch tokens:** - -### Evaluation -- **Temperature scaling (T=0.90):** auto-calibrated grid -- **Sliding window (stride=16):** evaluation protocol - -### Compression -- **Bit-packing + LZMA (preset=9):** binary weights pack at exactly 1 bit/param before LZMA entropy coding -- **FP8 QAT (e4m3):** for non-binary parameters. Clean roundtrip, binary has no zero state, so `mean(|Q|)=1.0` always; no shrinkage correction needed -- **No EMA:** despite clean binary roundtrip math, EMA still hurts quality by 0.03 bpb in practice - -## Setup and Run - -```bash -# Environment setup (conda + Python 3.13 + PyTorch + FlashAttention-3 + Triton + dataset) -bash setup.sh - -# Activate and run -conda activate golf -SEED=42 bash run_cuda_binary.sh -``` - -
-Full run command - -```bash -RUN_ID=binary_run \ -DATA_PATH=./data/datasets/fineweb10B_sp8192 \ -TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model \ -ATTN_PROJ_TYPE=standard \ -LOGIT_HEAD_TYPE=standard \ -TVERSKY_MEMBERSHIP=sigmoid \ -TVERSKY_NUM_FEATURES=0 \ -TVERSKY_FEATURE_POOLS=0 \ -VOCAB_SIZE=8192 \ -BITNET_GROUP_SIZE=128 \ -BIGRAM_HASH=0 \ -EMBED_DIM=254 \ -TRAINING_DEPTH_RECURRENCE=0 \ -EVAL_DEPTH_RECURRENCE=0 \ -NUM_LAYERS=15 \ -MODEL_DIM=768 \ -NUM_KV_HEADS=4 \ -NUM_HEADS=8 \ -DIFF_ATTN=0 \ -MLP_MULT=4 \ -MLP_GROUPS=0 \ -MATRIX_OPTIMIZER=muon \ -ADAM_LR=0.05 \ -ADAM_WD=0.05 \ -MUON_BACKEND_STEPS=3 \ -MUON_MOMENTUM=0.95 \ -MUON_MOMENTUM_WARMUP_START=0.85 \ -MUON_MOMENTUM_WARMUP_STEPS=500 \ -MUON_WD=0.0 \ -MATRIX_LR=0.04 \ -SCALAR_LR=0.02 \ -TIED_EMBED_LR=0.02 \ -WARMDOWN_FRACTION=0.2 \ -LOGIT_SOFTCAP=10 \ -QK_GAIN_INIT=2.25 \ -ROPE_TYPE=yarn \ -YARN_MAX_LEN=2048 \ -ROPE_BASE=5000 \ -BATCH_TOKENS_START=0 \ -BATCH_SCHEDULE_FRACTION=0.33 \ -TRAIN_BATCH_TOKENS=524288 \ -SEQ_LEN_START=0 \ -SEQ_SCHEDULE_FRACTION=0.0 \ -TRAIN_SEQ_LEN=1024 \ -SMEAR=1 \ -ITERATIONS=50000 \ -WARMUP_STEPS=5 \ -MAX_WALLCLOCK_SECONDS=0 \ -VAL_LOSS_EVERY=0 \ -TRAIN_LOG_EVERY=500 \ -CHURN_LOG_EVERY=1000 \ -VAL_MAX_TOKENS=0 \ -TIE_EMBEDDINGS=1 \ -UNTIE_AT_FRACTION=0.00 \ -HEAD_LR=0.02 \ -CORR_WEIGHT_LR=0.02 \ -ACTIVATION=relu2 \ -SOFTCAP_TYPE=poly \ -MTP_HEADS=0 \ -REFINER=0 \ -REFINER_KERNEL=3 \ -SLIDING_EVAL=1 \ -SLIDING_EVAL_STRIDE=16 \ -SLIDING_BATCH_SIZE=256 \ -TEMP_SCALING=1 \ -FP_STORAGE=FP8 \ -EMA=0 \ -EMA_DECAY=0.995 \ -EMA_START_FRACTION=0.5 \ -SEED=42 \ -COMPILE_MODE=default \ -OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 train_gpt_cuda_binary.py -``` - -
- -## Compliance - -- [x] Artifact <=16,000,000 bytes (15,670,651) -- [x] Sliding window eval stride=16 -- [x] No test-time training on validation data -- [x] No network calls during evaluation -- [x] No external compute -- [x] Train time: **non-record submission** (7,763s/ 2.2h / 50,000 steps) diff --git a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/RESULTS.md b/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/RESULTS.md deleted file mode 100644 index 82fcd581f0..0000000000 --- a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/RESULTS.md +++ /dev/null @@ -1,1236 +0,0 @@ -# Parameter Golf — Complete Experiment Log - -**Author:** Ciprian-Florin Ifrim -**Date:** March 2026 - ---- - -## Challenge Overview - -Train the best language model that fits in a 16MB artifact and trains in under 10 minutes on 8×H100 SXM GPUs, evaluated by tokenizer-agnostic bits-per-byte (BPB) compression on the FineWeb validation set. - -- **Baseline:** 1.2244 bpb (9L 512d int8+zlib, 1k vocab) -- **Our best (ternary, valid):** 1.1565 bpb sliding (P2, 10L 768d relu² 4×MLP fp8, EMBED_DIM=254, seed=42, 16.00MB) -- **Our best (binary, unconstrained):** 1.1239 bpb sliding (15L 768d binary relu² 4×MLP fp8, 50k steps / ~2h compute, 15.67MB) -- **Our best (quality, over budget):** 1.1771 bpb (F59, 12L 768d swiglu 3×MLP, 21.96MB) -- **Challenge period:** March 18 – April 30, 2026 -- **Compute sponsor:** OpenAI ($1M in compute credits) - -The challenge is framed as L(N) optimisation — minimising loss given fixed parameter count N, unconstrained by data, compute, steps, or architecture. Related challenges include NanoGPT Speedrun (L(T): lowest loss given constrained time) and NanoGPT Slowrun (L(D): lowest loss given constrained dataset). - ---- - -## Run Numbering Convention - -| Prefix | Description | -|--------|-------------| -| Plain (1–100) | Dev runs on RTX 5090, 100 steps | -| R prefix (R1...) | Record runs — 600s on 8×H100, leaderboard-targeted | -| S prefix (S1...) | Scaling runs — 1500 steps or 300s on 8×H100, controlled sweeps | -| SB prefix (SB1...) | Binary scaling runs | -| F prefix (F1...) | Final runs — 600s on 8×H100, official submissions | -| P prefix (P1...) | Pushed/submission runs — final config pushed to GitHub | - -Additionally, 20 early architecture iterations were performed on MLX (Mac Studio M1 Ultra, 32GB unified memory) and 2 on MPS (MacBook Pro M1 Pro, 32GB unified memory) for rapid prototyping before GPU scaling. - -> **Note:** This document covers ~85 named runs (F, S, R series). An additional ~165 dev runs (plain numbered 1–100, repeated sweeps, smoke tests) were conducted but are not individually listed. Key findings from those runs are incorporated into the sweep tables and decision rationale. Separate synthetic-data notebooks were used to isolate the behaviour of specific techniques (Tversky similarity, linear alternatives, grouped projections) before committing H100 compute. - ---- - -## Hardware - -| System | Spec | Notes | -|--------|------|-------| -| Dev | RTX 5090 32GB, single GPU | Triton smem ceiling 101KB/SM; blocks value embeddings and some kernels | -| Mac (MLX) | Mac Studio M1 Ultra 32GB | MLX early iteration, 20 runs | -| Mac (MPS) | MacBook Pro M1 Pro 32GB | MPS early iteration, 2 runs | -| Final | 8×H100 SXM 80GB | Primary training platform | - -**Step times at 768d (12L):** relu² 2x: 89ms | relu² 3x: 99ms | relu² 4x: 91ms | swiglu 3x: 127ms | leaky relu 3x: 103ms - -**Step times at 512d:** 26L baseline: 149ms → 136ms with FA3 → 127ms with FA3 + fusions + EMBED=256 at 25L - -**FlashAttention-3** reduced step time by ~9% (~380 free training steps per 600s run). - -**Kernel fusion optimisations** (fused QKV + fused SwiGLU + dataloader + softcap) saved a further ~7-10ms/step. - -**Width vs depth discovery:** 12L 768d at 106ms/step gets ~5640 steps in 600s vs ~4720 steps for 25L 512d — 920 extra steps from the faster per-step time of wider/shallower models. Final 10L 768d 4×MLP at 91.8ms/step gets ~6530 steps. - ---- - -## Architecture: Ternary U-Net Transformer - -### Quantisation Scheme - -BitNet b1.58 ternary quantisation — weights constrained to {−1, 0, +1} with per-group absmean scaling. Approximately 1.6 bits per parameter. - -**Compression pipeline:** Base-3 packing (5 trits/byte) or bitmask packing → LZMA (preset=9). Best method auto-selected per run. Bitmask wins when zero fraction is high. - -**Quantisation shrinkage fix:** When ternary Q contains zeros, `mean(|Q|) < 1.0`, causing scale mismatch on reload. Fix: inflate by `1/mean(|Q|)` during dequantisation. Eliminates all roundtrip gaps. - -### U-Net Skip Connections - -The model uses a U-Net style encoder/decoder structure with learned skip connections. The first `num_layers // 2` blocks (encoder) store their outputs; the second half (decoder) receives these via `x = x + skip_weight[i] * skips.pop()`. This allows the decoder to simultaneously access high-level semantic representations (from deep processing) and low-level token-level features (from early processing), without requiring the decoder to reconstruct low-level information from the compressed residual stream. - -Additionally, each block receives `x0` (the original input embedding) via a learned residual mix: `x = mix[0] * x + mix[1] * x0`, giving every layer direct access to the raw token representation regardless of accumulated residual drift. - -For odd layer counts, the decoder receives the larger half (e.g. 27L → 13 encoder + 14 decoder), which is the standard U-Net convention — more processing power applied after skip injection. - -### Factored Embedding - -With `EMBED_DIM=254`, token embedding is `[8192, 254]` instead of `[8192, 768]`, with learned projections `embed_proj` (254→768) and `embed_proj_rev` (768→254) for the tied output head. - -**EMBED_DIM history:** Started at 128 (dev runs), upgraded to 256 after an optimizer coverage fix revealed that the projection layers had not been receiving gradients (−0.024 bpb improvement vs 128 once trained), then trimmed to 254 to fit artifact+code under the 16,000,000 byte budget (~0.0004 bpb cost, 0.00018/dim from 128→256 scaling data). - -### Fused Operations - -**Fused QKV:** Single `TernaryLinear(dim, dim + 2*kv_dim)`. **Fused SwiGLU/relu²:** Gate and up projections combined into single wide matrix. Combined saving: ~4-6ms/step. - -### Z-Loss Regularisation - -`1e-4 * logsumexp(logits)²` (from PaLM/Gemma) anchors logits near zero, keeping gradients sharp through the ternary STE. - ---- - -## Compression Scheme - -### Base-3 + LZMA (Primary) - -5 trits per byte (1.585 bits/trit), lossless. LZMA at preset=9 achieves ~39% reduction over int8+zlib. Ternary distribution at convergence: ~20–29% zeros, ~35–40% each ±1. The skewed distribution (more zeros) is exploited by LZMA's entropy coding. - -### Bitmask Compression (Alternative) - -Encodes "is this weight zero?" and "if nonzero, is it +1?" as separate bitmasks. Both methods are tried and the smaller is selected automatically. In practice, bitmask and base-3+LZMA produce nearly identical artifact sizes — bitmask wins marginally in some runs (e.g. S72: 15.84MB vs 15.87MB). Zero fraction would need to drop below ~5% for bitmask to provide a clear advantage; our zero fraction ranges from 17–29% at convergence, making bitmask non-competitive. - -### 3D Tensor Support - -Conv1d weights (`[dim, dim, kernel]`) are reshaped to 2D before ternary quantisation and restored to original shape on load. - -### FP8 QAT - -Non-ternary parameters (embeddings, projections) stored at fp8 (e4m3) with Quantisation-Aware Training via STE. Halves fp_params storage (~5MB → ~2.5MB). Typical roundtrip gap: 0.001–0.002 bpb. - ---- - -## Submission Runs (P prefix) — Ternary - -Configuration: F88 (10L 768d relu² 4×MLP fp8, WD=0, EMBED_DIM=254, 599s wallclock, TEMP=0.90) - -| Seed | Steps | val_bpb | RT bpb | Sliding bpb | Train Time | Eval Time | Artifact | Budget | -|------|-------|---------|--------|-------------|------------|-----------|----------|--------| -| 1337 | 6520 | 1.1825 | 1.1839 | **1.1568** | 599.1s | 428.7s | 15.92MB | 16.00/16.00MB | -| 42 | 6530 | 1.1816 | 1.1837 | **1.1565** | 599.7s | 429.3s | 15.92MB | 15.99/16.00MB | -| 7 | 6530 | 1.1823 | 1.1850 | **1.1578** | 599.6s | 429.0s | 15.92MB | 15.99/16.00MB | -| **Mean** | **6527** | **1.1821** | **1.1842** | **1.1570** | **599.5s** | **429.0s** | **15.92MB** | | -| **Std** | **5** | **0.0005** | **0.0007** | **0.0007** | **0.3s** | **0.3s** | **0.00MB** | | - -All three seeds fit within the 16,000,000 byte budget. The standard deviation of 0.0007 bpb across seeds confirms high reproducibility. All runs achieve p < 0.001 improvement over the 1.2244 bpb baseline. - -### Batch Size Sensitivity (Ternary, 599s wallclock) - -| Batch Tokens | Steps | ms/step | val_bpb | Sliding bpb | Tokens Seen | Fits Budget | -|-------------|-------|---------|---------|-------------|-------------|-------------| -| 262,144 | 10,000 | 49 | 1.2413 | — | 2.6B | No | -| **524,288** | **6,530** | **92** | **1.1850** | **1.1578** | **3.4B** | **Yes** | -| 1,048,576 | 3,480 | 172 | 1.1925 | 1.1659 | 3.5B | No | - -524k batch tokens is the optimal operating point. Halving the batch (262k) doubles the step count but degrades quality by 0.056 bpb due to noisier gradients interacting poorly with the ternary STE. Doubling it (1M) sees similar total tokens but fewer gradient updates, costing 0.008 bpb. - ---- - -## Current Best Configuration - -### Ternary: 10L 768d relu² 4×MLP fp8, WD=0, EMBED_DIM=254 - -```bash -NUM_LAYERS=10 MODEL_DIM=768 NUM_HEADS=8 -NUM_KV_HEADS=4 MLP_MULT=4 VOCAB_SIZE=8192 -ACTIVATION=relu2 LOGIT_SOFTCAP=10 SOFTCAP_TYPE=poly -QK_GAIN_INIT=2.25 ROPE_BASE=5000 ROPE_TYPE=yarn -YARN_MAX_LEN=2048 EMBED_DIM=254 TIE_EMBEDDINGS=1 -BITNET_GROUP_SIZE=128 FP_STORAGE=FP8 MUON_WD=0.0 -MATRIX_LR=0.04 SCALAR_LR=0.02 TIED_EMBED_LR=0.02 -MUON_BACKEND_STEPS=3 MUON_MOMENTUM=0.95 WARMDOWN_FRACTION=0.2 -MAX_WALLCLOCK_SECONDS=599 -SLIDING_EVAL=1 SLIDING_EVAL_STRIDE=16 TEMP_SCALING=1 -TRAIN_BATCH_TOKENS=524288 -``` - -| Metric | Value | -|--------|-------| -| val_bpb (mean) | 1.1821 | -| RT bpb (mean) | 1.1842 | -| Sliding bpb (mean) | 1.1570 | -| Artifact + code | 15,992,753–15,995,705 / 16,000,000 bytes | -| Steps | 6520–6530 | -| ms/step | 91.8 | -| zero_frac | 0.335–0.336 | -| optimal_T | 0.90 | -| Params | 73,685,840 | - ---- - -## Dev Runs (RTX 5090, 100–500 steps) - -### Phase 0 — Ternary vs Binary (500 steps, 16L 512d, 1k vocab) - -| Run | Config | val_bpb | RT bpb | Artifact | ms/step | -|-----|--------|---------|--------|----------|---------| -| 17 | Ternary baseline | 1.7110 | 1.7300 | 23.95MB | 1312 | -| 18 | Binary {−1,+1} | 1.7121 | 1.7316 | 23.93MB | 1309 | - -Ternary wins by 0.0016 bpb. The zero state provides representational benefit. - ---- - -### Phase 1 — Training Techniques (100 steps, 9L 512d, 1k vocab) - -| Run | Config | val_bpb | RT bpb | Artifact | Notes | -|-----|--------|---------|--------|----------|-------| -| 19 | Ternary 16L 512d baseline | 2.3371 | 2.3793 | 7.33MB | | -| 20 | + Untie lm_head at 2/3 | 2.3569 | 2.3983 | 8.13MB | Deferred — needs wallclock fix | -| 21 | + Value embeddings | — | — | — | Blocked: RTX 5090 Triton smem | -| 22 | + Smear module | 2.3593 | 2.3985 | 7.33MB | Deferred — gate needs many steps | -| 23 | Baseline 9L 512d | 2.4483 | 2.4768 | 4.45MB | Switched from 16L | -| 24 | + Polynomial softcap | 2.3981 | 2.4438 | 4.45MB | **−0.033 rt** | -| 25 | + Seq length schedule | 2.4633 | 2.5106 | 4.45MB | Deferred — recompile cost | -| 26 | + NorMuon | 2.4018 | 2.4104 | 4.40MB | **−0.033 rt**, 5× smaller RT gap | -| 27 | + Grad accum delay | 2.6298 | 2.6571 | 4.40MB | Deferred — needs 2000+ steps | - ---- - -### Vocabulary Sweep (100 steps, 9L 512d) - -| Run | Vocab | val_bpb | RT bpb | Artifact | Notes | -|-----|-------|---------|--------|----------|-------| -| 23 | 1024 | 2.4483 | 2.4768 | 4.45MB | Baseline | -| 28 | 4096 | 2.0930 | 2.0974 | 6.68MB | −0.32 vs 1k | -| **29** | **8192** | **1.9946** | **1.9990** | **9.64MB** | **−0.42 vs 1k — largest single win** | - -8192 vocab locked. The tokeniser merges ~1.57× more aggressively than 1k, directly reducing BPB. Val token count drops from 63.8M (sp1024) to 40.5M (sp8192) for the same 50k documents. - ---- - -### Activation Sweep (100 steps, 9L 512d, 8k vocab) - -| Run | Activation | val_bpb | RT bpb | Artifact | ms/step | -|-----|-----------|---------|--------|----------|---------| -| 29 | relu2 | 1.9946 | 1.9990 | 9.64MB | 838 | -| 30 | relu | 1.9846 | 1.9879 | 9.63MB | 830 | -| **31** | **SwiGLU** | **1.9704** | **1.9743** | **10.70MB** | **960** | -| 32 | SwiGLU + MTP(2) | 1.9627 | 1.9672 | 10.69MB | 1111 | - -SwiGLU with MTP auxiliary loss gives −0.032 bpb but +16% slower. SwiGLU alone gives −0.025 bpb. MTP deferred. - ---- - -### Embedding Factorization Sweep (100 steps, 9L 512d, 8k vocab) - -| Run | EMBED_DIM | val_bpb | RT bpb | RT gap | Artifact | -|-----|-----------|---------|--------|--------|----------| -| 33a | 0 (=512) | 1.9931 | 1.9962 | 0.003 | 9.63MB | -| **33d** | **128** | **1.9656** | **1.9656** | **0.000** | **9.12MB** | -| 33c | 256 | 2.0538 | 2.1339 | 0.080 | 6.68MB | -| 33e | 64 | 2.0936 | 2.0968 | 0.003 | 4.49MB | -| 33f | 1024 | 2.0709 | 2.1845 | 0.114 | 15.60MB | - -128 was optimal at dev scale. After an optimizer fix revealed the projection layers had not been training, 256 became optimal at full convergence — see EMBED_DIM Sweep at full convergence. - ---- - -### Tversky Neural Network Investigation - -Based on Doumbouya et al. (2025). Three-term Tversky similarity: `S = theta * f(A intersection B) - alpha * f(A - B) - beta * f(B - A)` with learned membership functions. - -**Feature count sweep (FP16 features, ternary prototypes, 100 steps, 9L 512d):** - -| Run | Features | val_bpb | RT bpb | RT gap | Artifact | -|-----|----------|---------|--------|--------|----------| -| — | No Tversky | 1.9751 | 1.9751 | 0.000 | 5.33MB | -| 38 | 16 | 1.9877 | 2.0186 | 0.031 | 5.46MB | -| 39 | 32 | 1.9843 | 2.0133 | 0.029 | 5.57MB | -| 40 | 64 | 1.9790 | 2.0097 | 0.031 | 5.79MB | -| **41** | **128** | **1.9427** | **1.9865** | **0.044** | **6.20MB** | -| 42 | 256 | 1.9737 | 2.0863 | 0.113 | 5.63MB | -| 43 | 512 | 2.0036 | 2.0965 | 0.093 | 5.90MB | -| 44 | 128 + shrinkage fix | 1.9425 | **1.9425** | **0.000** | 6.20MB | - -Tversky showed genuine quality benefit (~-0.017 bpb) at dev scale with 128 features and fp16 prototype storage. However, subsequent investigation at full convergence (12L 768d) and with corrected prototype storage showed all Tversky variants within noise of the linear baseline. Additional experiments included full ternary prototypes, shared feature pools across layers, no-features mode, logit-head application, and different membership functions (sigmoid, poly, tanh). A synthetic-data notebook confirmed that Tversky's asymmetric similarity only helps on tasks with genuine directional feature relationships (hypernym/hyponym, cause/effect); next-token prediction on FineWeb web text is not such a task. - -At the 768d architecture with relu², Tversky also incurred a 19ms/step overhead because the smaller MLP no longer masked the compute cost. - -**Conclusion:** Tversky is quality-neutral on FineWeb language modelling regardless of configuration. Not a quantisation issue, not an optimizer issue — the task simply does not benefit from asymmetric similarity. - ---- - -### Key Hyperparameter Sweeps (100 steps, 9L 512d, 8k vocab) - -**QK_GAIN_INIT sweep:** - -| Run | QK_GAIN | val_bpb | Delta | -|-----|---------|---------|-------| -| 75 | 1.0 | 2.0007 | +0.0076 | -| 73 | 1.5 | 1.9931 | baseline | -| 81 | 2.15 | 1.9913 | −0.0018 | -| **79** | **2.25** | **1.9898** | **−0.0033** | -| 77 | 2.5 | 1.9915 | −0.0016 | -| 80 | 2.75 | 1.9975 | +0.0044 | -| 78 | 3.0 | 2.0011 | +0.0080 | - -Clear inverted-U response. **QK_GAIN_INIT=2.25 locked.** - -**LOGIT_SOFTCAP sweep:** - -| Run | SOFTCAP | val_bpb | Delta | -|-----|---------|---------|-------| -| 74 | 5 | 1.9942 | −0.0013 | -| **73** | **10** | **1.9931** | **−0.0024** | -| 72 | 20 | 1.9935 | −0.0020 | -| 71 | 50 | 1.9957 | +0.0003 | - -**LOGIT_SOFTCAP=10 locked.** - -**Softcap type (poly vs tanh):** - -| Run | Type | val_bpb | Notes | -|-----|------|---------|-------| -| S23 | poly | 1.3680 | | -| S24 | tanh | 1.3693 | | -| S28/S29 | both at EMBED=1024 | 1.3460–1.3462 | Identical at convergence | - -Zero effect. Polynomial retained as default. - -**ROPE_BASE sweep:** - -| Run | ROPE_BASE | val_bpb | Notes | -|-----|-----------|---------|-------| -| **70** | **5000** | **1.9959** | Best at short training | -| 73 | 10000 | 1.9931 | Close second | -| 69 | 20000 | 2.0008 | | -| 68 | 50000 | 2.0017 | | - -**KV Heads:** - -| Run | KV_HEADS | val_bpb | Artifact | -|-----|----------|---------|----------| -| **58** | **4 (GQA)** | **1.9955** | **7.75MB** | -| 66 | 8 (MHA) | 2.0148 | 8.46MB | - -**MLP_MULT:** - -| Run | MLP_MULT | val_bpb | Artifact | -|-----|----------|---------|----------| -| **58** | **2** | **1.9955** | **7.75MB** | -| 64 | 3 | 2.0004 | 9.09MB | -| 65 | 4 | 1.9992 | 10.39MB | - -**Storage precision:** - -| Run | Storage | val_bpb | RT bpb | RT gap | Artifact | -|-----|---------|---------|--------|--------|----------| -| **90** | **fp16** | **1.9656** | **1.9656** | **0.000** | **9.06MB** | -| 91 | fp8 | 1.9662 | 1.9702 | 0.004 | 7.83MB | -| 92 | fp4 | 1.9661 | 1.9955 | 0.029 | 7.11MB | - -**TTT-LoRA sweep (100 steps, ROPE=5000):** - -| Run | Rank | LR | TTT bpb | Delta | -|-----|------|-----|---------|-------| -| **85** | **8** | **0.01** | **1.9368** | **−0.0315** | -| 86 | 8 | 0.005 | 1.9378 | −0.0312 | -| 87 | 8 | 0.02 | 1.9644 | −0.0038 | -| **88** | **4** | **0.01** | **1.9371** | **−0.0285** | -| 89 | 16 | 0.01 | OOM | — | - -TTT confirmed working at dev scale (−0.0315 bpb). Incompatible at convergence — see TTT investigation. - -**EMBED_DIM sweep at 512d (12L, 100 steps):** - -| Run | EMBED_DIM | Tversky feat | RT bpb | Artifact | bpb/MB efficiency | -|-----|-----------|-------------|--------|----------|-------------------| -| 95 | 64 | 128 | 2.1961 | 8.40MB | worst | -| 98 | 96 | 128 | 2.0356 | 8.74MB | | -| 97 | 128 | 128 | 1.9656 | 9.12MB | best | -| 99 | 192 | 128 | 2.0409 | 10.07MB | | -| 94 | 256 | 128 | 2.0703 | 10.93MB | | -| 100 | 256 | 256 | 2.0340 | 10.09MB | RT gap 0.021 | -| 96 | 512 (off) | 128 | 2.0642 | 13.50MB | | - -128 confirmed optimal at dev scale. - ---- - -### Architecture Sizing Table (Ternary, EMBED_DIM=128, standard proj) - -| Config | Layers | Artifact | Under 16MB? | RT gap | Headroom | -|--------|--------|----------|-------------|--------|----------| -| fp16 | 20 | 14.23MB | Yes | 0.0001 | 1.77MB | -| **fp16** | **22** | **15.48MB** | **Yes** | **0.0001** | **0.52MB** | -| fp16 | 24 | 16.74MB | No | — | −0.74MB | -| fp8 QAT | 24 | 14.63MB | Yes | 0.028 | 1.37MB | -| fp8 QAT | 26 | 15.77MB | Yes | 0.066 | 0.23MB | -| **fp8 QAT** | **27** | **15.42MB** | **Yes** | **0.0025** | **0.58MB** | -| fp8 QAT | 28 | 15.92MB+code | Marginal | 0.0029 | ~0MB | -| fp8 QAT | 30 | 16.92MB | No | 0.0029 | −0.92MB | - ---- - -## H100 Record Runs (R prefix) - -**Hardware:** 8×H100 SXM 80GB | **Time limit:** 600 seconds - -| Run | Config | Steps | val_bpb | RT bpb | Artifact | Notes | -|-----|--------|-------|---------|--------|---------|-------| -| R1 | 22L Tversky fp16 | 4299 | 1.2789 | 1.2792 | 15.80MB | | -| R2 | 26L standard fp16 | 3973 | 1.2649 | 1.2650 | 15.85MB | Pre-LR tuning best | -| R3 | 16L Tversky fp16 | 5949 | 1.2900 | 1.2904 | 11.95MB | Too shallow | -| R4 | 9L Tversky fp16 | 10112 | 1.3374 | 1.3394 | 7.48MB | Way too shallow | -| R5 | 30L fp8 | 2852 | 1.2689 | 1.2815 | 17.22MB | Over budget | -| R6 | 26L fp16, 2× LR | ~4003 | 1.2991 | — | ~15.85MB | LR overshot | -| **R7** | **26L fp16, LR=0.02** | **4008** | **1.2608** | **1.2610** | **15.83MB** | **Best pre-FA3** | -| R8 | 26L fp16, LR=0.01 | 4017 | 1.2853 | 1.2855 | 15.72MB | LR too low | -| R9 | 26L BigramHash | 4010 | 1.2804 | 1.2802 | 15.81MB | BigramHash negative | -| R10 | 26L untie@66% | 3706 | 1.2754 | 1.2753 | 23.15MB | Over budget | -| R11 | 26L tied, updated code | 4009 | 1.2806 | 1.2808 | 15.81MB | Code regression | - -**LR sweep (R-series):** - -| LR | val_bpb | Notes | -|----|---------|-------| -| 0.08 | 1.2991 | Overshoots — ternary STE amplifies gradient noise | -| **0.02** | **1.2608** | **Optimal** | -| 0.01 | 1.2853 | Too slow | - ---- - -## Scaling Runs (S prefix) - -**Hardware:** 8×H100 SXM 80GB | **Steps:** 1500 | **Timer:** disabled (MAX_WALLCLOCK_SECONDS=0) -**Base config:** 26L 512d, EMBED_DIM=128, ROPE=5000, QK_GAIN=2.25, SOFTCAP=10, LR=0.02 all, VOCAB=8192, SwiGLU, SEED=1337 - ---- - -### Warmdown Sweep - -| Run | Fraction | val_bpb | -|-----|----------|---------| -| S3 | 10% | 1.3467 | -| **S1** | **20%** | **1.3438** | -| S2 | 30% | 1.3443 | -| S4 | 30% repeat | 1.3458 | -| S5 | 40% | 1.3501 | - -S2 vs S4 (identical config): 0.0015 bpb spread — confirmed seed variance floor. - -### Muon Backend Steps - -| Run | Steps | ms/step | val_bpb | -|-----|-------|---------|---------| -| S8 | 3 | 144.87 | 1.3491 | -| S9 | 4 | 146.61 | 1.3448 | -| **S1** | **5** | **149.19** | **1.3438** | -| S7 | 8 | 164.31 | 1.3441 | -| S6 | 10 | 157.95 | 1.3456 | - -At full convergence (F6 vs F1): 3 steps matches 5 due to +190 extra training steps. Locked at 3. - -### Muon Momentum - -| Run | Momentum | val_bpb | zero_frac | Artifact | -|-----|----------|---------|-----------|---------| -| S11 | 0.90 | 1.3680 | 0.179 | 15.39MB | -| **S1** | **0.95** | **1.3438** | **0.205** | **15.56MB** | -| S10 | 0.99 | 1.3505 | 0.259 | 15.78MB | - -Higher momentum increases zero_frac, inflating artifact size. - -### Architecture Experiments - -| Run | Config | ms/step | val_bpb | Notes | -|-----|--------|---------|---------|-------| -| S12 | 20L 640d (80M params) | 160.58 | 1.6676 | 17.75MB — over budget | -| **S1** | **26L 512d baseline** | **149.19** | **1.3438** | **Reference** | -| S13 | 26L, TRAINING_DR=2 | 281.63 | 1.3727 | ~795 effective steps, OOM at DR=3 | - -### Eval Depth Recurrence Sweep - -| Run | EVAL_DR | val_bpb | -|-----|---------|---------| -| S15 | 0/1 | 1.3685–1.3690 | -| S16 | 2 | 1.3688 | -| S17 | 3 | 1.3681 | -| S18 | 4 | 1.3690 | -| S19 | 5 | 1.3683 | - -Total range: 0.0009 bpb — pure noise. - -### Weight Decay (1500 steps) - -| Run | MUON_WD | val_bpb | zero_frac | Artifact | -|-----|---------|---------|-----------|---------| -| **S15** | **0.00** | **1.3685** | **0.179** | **15.39MB** | -| S20 | 0.04 | 1.3722 | 0.145 | 15.12MB | - -WD hurts at 1500 steps but saves 0.27MB. Reversed at full convergence — see Final Ternary Record Runs. - -### BigramHash - -| Run | Config | Steps | val_bpb | Artifact | -|-----|--------|-------|---------|---------| -| S21 | 26L + BigramHash | 1500 | 1.3681 | 15.45MB | -| R9 | 26L + BigramHash | 4010 | 1.2804 | 15.81MB | - -At full convergence: 0.020 bpb worse than R7. The 2.1MB fp16 cost of the bigram table displaces ternary layer depth at convergence. **Not viable within budget.** - -### Tied Embedding / Correction Weight / Untie Investigation - -| TIE_EMBEDDINGS | UNTIE_AT_FRACTION | LM_HEAD_RANK | Behaviour | -|---------------|-------------------|--------------|-----------| -| 0 | any | any | Untied from start — unstable, loss = log(8192) = 9.01 | -| 1 | 0.0 | 0 | Always tied — current best | -| 1 | 0.66 | 0 | Tied → full-rank untie at 66% of wallclock | -| 2 | 0.0 | 0 | Tied + correction weight residual on tok_emb | -| 2 | 0.66 | 0 | Tied + correction → full-rank untie at 66% | -| 2 | 0.66 | r | Tied + correction → SVD rank-r untie at 66% | - -**1500-step results:** - -| Run | TIE | UNTIE | RANK | val_bpb | Artifact | -|-----|-----|-------|------|---------|---------| -| S15 | 1 | 0.00 | 0 | 1.3685 | 15.39MB | -| S30 | 2 | 0.00 | 0 | 1.3678 | 15.39MB | -| S36 | 1 | 0.66 | 0 | 1.3648 | 22.83MB | -| **S37** | **2** | **0.66** | **0** | **1.3642** | **22.84MB** | -| S38 | 1 | 0.66 | 0 | 1.3667 | 22.84MB | -| S39 | 0 | 0.66 | 0 | 3.4890 | 10.88MB | - -Untie gives +0.005 bpb gain but adds 7.3MB — over budget. **TIE=1, no untie locked.** - -### LM Head Factorization (SVD-at-Untie) - -| Run | RANK | val_bpb | Artifact | Delta vs baseline | -|-----|------|---------|---------|-------------------| -| S37 | 0 (full) | 1.3642 | 22.84MB | +0.004 — over budget | -| S43 | 32 | 1.4873 | 17.27MB | −0.119 | -| S41 | 64 | 1.4243 | 17.60MB | −0.056 | -| S42 | 128 | 1.3889 | 18.40MB | −0.020 | - -SVD factorization does not recover within the remaining 34% of training. The model requires full-rank lm_head for 8192-class separability in 512-dimensional space. - -### Tied Embed LR Sweep - -| Run | TIED_EMBED_LR | MATRIX_LR | SCALAR_LR | val_bpb | -|-----|--------------|-----------|-----------|---------| -| S33 | 0.01 | 0.02 | 0.02 | 1.3723 | -| **S15** | **0.02** | **0.02** | **0.02** | **1.3685** | -| S34 | 0.03 | 0.02 | 0.02 | 1.3742 | - -Symmetric degradation. **TIED_EMBED_LR=0.02 locked.** - -### TTT-LoRA Investigation - -Test-time training with per-document LoRA adapters. Confirmed working at dev scale (−0.0315 bpb). Incompatible at convergence across 6 diagnostic runs. - -| Run | Config | val_bpb | TTT bpb | Notes | -|-----|--------|---------|---------|-------| -| S22 | TTT_LR=0.01 | 1.3690 | 1.5065 | TTT hurts | -| S23 | No lm_head_lora | 1.3690 | 1.4993 | Still hurts | -| S24 | tanh softcap | 1.3693 | 1.4982 | No improvement | -| S25 | Q/V loras only | 1.3692 | 1.5193 | Worse | -| S26 | EMBED_DIM=1024 | 1.3473 | 1.4746 | Bottleneck not cause | -| S27 | 9L (original depth) | 1.4039 | 1.5189 | Still incompatible at 9L | - -**Root cause:** Every `TernaryLinear` applies RMSNorm to its input before the weight multiply. The LoRA adapter delta is computed on the pre-normalised representation, but injected into a forward pass where base weights operate on a differently-normalised space. At 100 steps the model is poorly calibrated and LoRA signal dominates. At convergence, the base model's representations are precisely calibrated to this normalised space, and any LoRA delta corrupts rather than adapts. This incompatibility is architectural. **TTT permanently disabled.** - -### MTP (Multi-Token Prediction) - -| Run | MTP_HEADS | ms/step | val_bpb | Notes | -|-----|-----------|---------|---------|-------| -| **S47** | **0** | **149** | **1.3693** | **Baseline** | -| S45 | 2 | 157 | 1.3704 | +0.0011 worse | -| S62 | 2 | 144 | 1.3727 | +0.0034 worse | - -Confirmed at both 1500 steps and full convergence (post-fix retest: 0.006 bpb worse at both MTP=1 and MTP=2). A 60M+ parameter, 1.58-bit model does not have the parameter bandwidth for auxiliary future-planning objectives. - -### Smear Module - -| Run | SMEAR | val_bpb | ms/step | -|-----|-------|---------|---------| -| **S48** | **0** | **1.3687** | **149** | -| S49 | 1 | 1.3675 | 182 | - -+22% slower, −0.0012 bpb at 1500 steps. At full 600s wallclock, smear costs ~740 fewer training steps. Not viable within the ternary 10-minute budget but explored further in the binary track. - -### Sequence Length Schedule - -| Run | Config | val_bpb | ms/step avg | -|-----|--------|---------|-------------| -| S48 | baseline | 1.3687 | 149 | -| S51 | smear + seq@33% | 1.3660 | ~240 | -| S52 | smear + seq@33% repeat | 1.3640 | ~221 | -| **S58** | **smear + seq@33% + YaRN** | **1.3628** | **~221** | - -Real gain at 1500 steps but severe step penalty at full 600s. **Disabled for final runs.** - -### Batch Size Schedule - -| Run | Config | val_bpb | -|-----|--------|---------| -| S48 | baseline | 1.3687 | -| S50 | smear + batch | 1.3698 | -| S53 | smear + seq + batch | 1.3667 | - -Noisier gradients interfere with ternary STE convergence. **Not viable.** - -### YaRN Positional Encoding - -| Run | Config | val_bpb | -|-----|--------|---------| -| S48 | RoPE baseline | 1.3687 | -| S54 | YaRN 4096 | 1.3705 | -| S55 | YaRN 2048 | 1.3679 | -| S56 | YaRN 2048 + seq@33% | 1.3672 | -| S57 | YaRN 2048 + seq@50% + smear | 1.3637 | -| **S58** | **YaRN 2048 + seq@33% + smear** | **1.3628** | - -YaRN 4096 hurts (scale=0.25 too aggressive). YaRN 2048 marginally better. **YaRN 2048 retained; seq schedule disabled.** - -ROPE_BASE with YaRN: S63 (10000) = 1.3692, **S61 (5000) = 1.3686**. ROPE_BASE=5000 locked. - -### Sliding Window Evaluation - -| Run | Stride | Sliding bpb | Eval time | -|-----|--------|-------------|-----------| -| S60 | 16 | 1.3452* | >600s | -| S67 | 24 | 1.3146 | 592s | -| **S61/S66** | **32** | **1.3139–1.3452*** | **~350s** | - -*S60/S61 used incorrect momentum=0.90. At full convergence (F1): stride=32 gives 1.2312 sliding bpb in 280s. - -### Temperature Scaling - -Grid search over T in [0.80, 1.20] on 65,536 training tokens. 5-point grid. Optimal T was consistently 1.00 at convergence for the 512d SwiGLU architecture. At the 768d relu² architecture, T=0.90 was consistently optimal (relu² logits slightly underconfident). **TEMP_SCALING=1 in all final runs.** - -### Group Size Sweep (S73–S76, 2000 steps, 27L) - -| Run | Group Size | Layers | val_bpb | Artifact | Total | -|-----|-----------|--------|---------|----------|-------| -| S76 | 32 | 27 | 1.2739 | 17.64MB | 17.73MB | -| S75 | 64 | 27 | 1.2683 | 16.22MB | 16.31MB | -| **S73** | **128** | **27** | **1.2677** | **15.53MB** | **15.62MB** | -| S74 | 256 | 27 | 1.2699 | 15.19MB | 15.28MB | - -128 wins on both quality and compression. - -### Skip Weights Init — Zero vs Ones (S77) - -| Run | Init | val_bpb | artifact | -|-----|------|---------|---------| -| S73 | ones | 1.2677 | 15.62MB | -| S77 | zeros | 1.2781 | 15.62MB | - -Zero-init is **0.0104 bpb worse**. Decoder needs skip signal from step 0. - -### FP8/FP4 Storage with QAT - -**FP8 sweep:** - -| Run | Config | val_bpb | RT bpb | RT gap | Sliding bpb | Artifact | -|-----|--------|---------|--------|--------|-------------|---------| -| S64 | 26L fp16 | 1.3390 | 1.3390 | 0.000 | 1.3150 | 15.58MB | -| S65 | 30L fp8, no QAT | 1.3346 | 1.3394 | 0.0048 | 1.3150 | 16.92MB | -| S66 | 30L fp8, QAT | 1.3351 | 1.3380 | 0.0029 | **1.3139** | 16.92MB | -| S71 | 27L fp8, QAT | 1.3380 | 1.3405 | 0.0025 | 1.3164 | 15.42MB | -| S72 | 28L fp8, QAT | 1.3377 | 1.3406 | 0.0029 | 1.3166 | 15.92MB | - -QAT reduces fp8 RT gap from 0.0048 to 0.0029 (40% improvement). However at full convergence (F3), 28L fp8 QAT (1.2353 sliding) loses to 26L fp16 (1.2312 sliding). - -**FP4 sweep:** - -| Run | Config | val_bpb | RT bpb | RT gap | Sliding bpb | Artifact | -|-----|--------|---------|--------|--------|-------------|---------| -| S68 | 30L fp4 QAT | 1.3377 | 1.3643 | **0.0266** | 1.3404 | 16.49MB | -| S69 | 26L fp4 Tversky QAT | 1.3543 | 1.3835 | **0.0292** | 1.3606 | 15.01MB | -| S70 | 28L fp4 QAT | 1.3405 | 1.3666 | **0.0261** | 1.3424 | 15.43MB | - -FP4 RT gap of ~0.026–0.029 even with QAT is unrecoverable. **FP4 not viable at any layer count.** - -### EMBED_DIM Sweep (Full Convergence, 25L) - -| Config | EMBED_DIM | Steps | val_bpb | sliding_bpb | artifact | Notes | -|--------|-----------|-------|---------|-------------|---------|-------| -| S80 | 0 (=512) | 4500 | 1.1902 | ~1.168 est | 19.78MB | OOM on sliding eval | -| **F22** | **256** | **4720** | **1.2012** | **1.1739 (s16)** | **16.21MB** | **Best 512d result** | -| F16-era | 128 | 4310 | 1.2245 | — | 16.19MB | Pre-fix baseline | - -**EMBED_DIM=256 locked.** Budget impact: fp_params ~4.85MB vs ~2.48MB at 128 (+2.37MB). - ---- - -## Final Ternary Record Runs (F prefix) - -**Hardware:** 8×H100 SXM 80GB | **FlashAttention-3 enabled** | **Time limit:** 600 seconds - -| Run | Config | Steps | val_bpb | RT bpb | Sliding bpb | Eval time | Artifact | -|-----|--------|-------|---------|--------|-------------|-----------|---------| -| **F1** | **26L fp16, no smear, no seq** | **4362** | **1.2560** | **1.2560** | **1.2312** | **280s** | **15.85MB** | -| F2 | 26L fp16, smear + seq@33% | 3044 | 1.2779 | 1.2778 | 1.2535 | 390s | 15.85MB | -| F3 | 28L fp8 QAT, no smear, no seq | 4019 | 1.2571 | 1.2601 | 1.2353 (s24) | 385s | 16.14MB | -| F4 | 26L fp16, EMA=1 | 4145 | 1.2589 | 2.3307 | — | — | 14.52MB | -| F5 | 26L fp16, EMA fix v1 (smoke) | 407 | 1.5483 | 2.3642 | — | — | 14.90MB | -| F6 | 26L fp16, MUON_BACKEND_STEPS=3 | 4552 | 1.2558 | 1.2558 | 1.2311 (s24) | 362s | 15.81MB | -| F7 | 26L fp16, WD=0.04, steps=3 | 4499 | 1.2552 | 1.2551 | 1.2302 (s24) | 362s | 15.60MB | -| F8 | 28L fp16, WD=0.04, steps=2, LR=0.02 | 4219 | 1.2799 | 1.2801 | 1.2558 (s16) | 577s | 15.92MB | -| F9 | 28L fp16, WD=0.04, steps=2, LR=0.03 | 4231 | 1.2673 | 1.2676 | 1.2431 (s16) | 577s | 16.00MB | -| F10 | 28L fp16, WD=0.04, steps=2, LR=0.04 | 4226 | 1.2636 | 1.2636 | 1.2391 (s16) | 578s | 16.01MB | -| F11 | 28L fp16, WD=0.04, steps=3, LR=0.04 | 4137 | 1.2489 | 1.2488 | — | — | 16.69MB | -| F12 | 28L fp16, WD=0.04, steps=4, LR=0.04 | 4047 | 1.2496 | 1.2500 | — | — | 16.71MB | -| F13 | 28L fp16, WD=0.04, steps=3, LR=0.05 | 4048 | 1.2512 | 1.2510 | — | — | 16.73MB | -| F14 | 28L fp16, WD=0.04, steps=3, LR=0.08 | 4036 | 1.2576 | 1.2574 | — | — | 16.75MB | -| F15 | 27L fp16, AdamW matrix, LR=0.01 | 4676 | 1.2943 | 1.2942 | — | — | 15.71MB | -| F16 | 27L fp16, Muon, LR=0.04, WD=0.04 | 4310 | 1.2245 | — | — | — | 16.19MB | -| **F22** | **25L fp16, EMBED=256, steps=3, WD=0.04** | **4720** | **1.2012** | **1.2011** | **1.1739 (s16)** | **493s** | **16.21MB** | - -**Key findings:** F22 with EMBED_DIM=256 and corrected optimizer achieves 0.055 bpb improvement over F1 (the best pre-fix config). 28L extensively attempted (F8–F14) but artifact always over budget at competitive LR. AdamW for matrix params (F15) is clearly worse than Muon. - ---- - -## Phase 2 — Post-Optimizer-Fix Experiments (25L 512d EMBED=256) - -### EMA (Exponential Moving Average) - -| Run | Config | Steps | val_bpb | RT bpb | Artifact | -|-----|--------|-------|---------|--------|----------| -| F4 | EMA=1, decay=0.999 | 4145 | 1.2589 | 2.3307 | 14.52MB | -| — | Full run with EMA | 4144 | 1.2584 | 1.3776 | 14.94MB | - -**EMA is fundamentally incompatible with ternary quantization.** EMA averaging in fp32 produces smoother, more zero-centered weights. More latent weights near zero → more round to 0 in ternary → scale factor mismatch → 0.13 bpb RT gap. **Permanently disabled.** - -### Muon Backend Steps — Full Convergence - -| Run | Steps | step_avg | val_bpb | sliding_bpb | artifact | -|-----|-------|----------|---------|-------------|---------| -| F1 (steps=5) | 4362 | 137ms | 1.2560 | 1.2312 | 15.85MB | -| F6 (steps=3) | 4552 | 131ms | 1.2558 | 1.2311 | 15.81MB | - -6ms/step saving → 190 extra steps → quality equivalent. **MUON_BACKEND_STEPS=3 locked.** - -### Weight Decay — Full Convergence - -| Run | WD | Steps | val_bpb | sliding_bpb | zero_frac | artifact | -|-----|-----|-------|---------|-------------|-----------|---------| -| F6 | 0.00 | 4552 | 1.2558 | 1.2311 | 0.294 | 15.81MB | -| F7 | 0.04 | 4499 | 1.2552 | 1.2302 | 0.221 | 15.60MB | - -WD=0.04 wins at full convergence on the 26L architecture. However at 10L 4×MLP (Phase 4), WD=0.00 was better — wider MLP needs full weight freedom. - -### MTP Retest (Post-Fix) - -| Run | MTP_HEADS | Steps | step_avg | val_bpb | artifact | -|-----|-----------|-------|----------|---------|---------| -| F22 baseline | 0 | 4720 | 127ms | 1.2012 | 16.29MB | -| Run 26 | 1 | 4560 | 131ms | 1.2074 | 16.30MB | -| Run 27 | 2 | 4420 | 135ms | 1.2074 | 16.29MB | - -**MTP confirmed not viable post-fix.** 0.006 bpb worse at both heads. **MTP_HEADS=0 permanently locked.** - -### Tversky Phase 2 (Post-Fix, 12L 768d, fp16 Prototypes) - -Comprehensive retest with corrected optimizer and fp16 prototype storage: - -| Run | Config | Features | Pools | val_bpb | RT gap | -|-----|--------|----------|-------|---------|--------| -| 49 | No Tversky | — | — | **1.1888** | 0.0002 | -| 50 | Attn proj only | 128 | 1 | 1.1893 | 0.0000 | -| 51 | Attn proj only | 256 | 1 | 1.1894 | 0.0001 | -| 52 | Attn proj only | 32 | 1 | 1.1898 | 0.0001 | -| 53 | Attn + head | 128 | 1 | 1.1892 | — | -| 54 | Attn + head | 128 | 0 (local) | 1.1897 | +0.0006 | - -All variants within 0.001–0.002 bpb of baseline — pure noise. Confirmed by synthetic-data analysis that Tversky's asymmetric similarity only helps on tasks with directional feature relationships, which next-token prediction on web text is not. - ---- - -## Phase 3 — Architecture Exploration (Post-Optimizer-Fix) - -### Width vs Depth - -The central Phase 3 finding: wider models with fewer layers beat deeper models. - -#### 768d Scaling Curve - -| Run | Layers | Steps | step_avg | val_bpb | Artifact | -|-----|--------|-------|----------|---------|----------| -| 34 | 8 | 8110 | 74ms | 1.2894 | 12.94MB | -| 30 | 12 | 5640 | 106ms | 1.1893 | 17.50MB | -| 38 | 14 | 4900 | 122ms | 1.1870 | 19.79MB | -| 33/37 | 16 | 4320 | 139ms | 1.1825–37 | 22.08MB | -| 39 | 18 | 3870 | 155ms | 1.1801 | 24.39MB | -| 36 | 20 | 3510 | 171ms | 1.1854 | 26.67MB | - -Peak at 18L, then step penalty dominates. 8L collapses (U-Net encoder too shallow). Seed variance: Run 33 vs 37 = 0.0012 bpb. - -#### Cross-Architecture Comparison - -| Config | Layers | Dim | Steps | val_bpb | -|--------|--------|-----|-------|---------| -| F22 | 25 | 512 | 4720 | 1.2012 | -| Run 30 | 12 | 768 | 5640 | 1.1893 | -| Run 40 | 8 | 1024 | 5870 | 1.1858 | -| Run 41 | 10 | 896 | 5400 | 1.1862 | -| Run 35 | 20 | 640 | 4170 | 1.1927 | -| Run 42 | 6 | 896 | 8510 | 1.2157 | - -Width beats depth: 12L 768d (1.1893) beats 25L 512d (1.2012). Minimum viable depth: 768d ~10–12L, 896d ~10L, 1024d ~8L. - -### FP8 at 768d - -| Run | Layers | Storage | val_bpb | RT bpb | RT gap | -|-----|--------|---------|---------|--------|--------| -| 49 | 12 | fp16 | 1.1888 | 1.1886 | 0.0002 | -| 42 | 13 | fp8 | 1.1879 | 1.1900 | 0.0021 | - -FP8 RT gap acceptable at 768d. Enables extra layers within budget. - -### LM_HEAD_RANK Investigation (Post-Fix, 768d) - -| Run | Config | val_bpb | RT bpb | Total | Notes | -|-----|--------|---------|--------|-------|-------| -| Run 49 | baseline | 1.1888 | 1.1886 | 17.50MB | Reference | -| Run 43 | TIE=2, rank=256, fp8 | 1.2021 | 1.2028 | 20.41MB | Artifact bloated | -| Run 44 | TIE=0, rank=512, untie=0.0 | 1.3196 | 1.3195 | 16.92MB | Random head, no learning | -| Run 45 | TIE=2, rank=512, fp16 | 1.2312 | 1.2317 | 26.87MB | Catastrophic artifact blowup | - -Root cause: the SVD factors U and V require fp16/fp8 precision to maintain approximation quality. At any viable compression level, the two new matrices cost more storage than the original tied embedding saves. **Not viable.** - ---- - -## Phase 4 — Final Architecture Search - -### Activation Sweep (12L 768d 3×MLP, 600s) - -| Run | Activation | MLP | ms/step | Steps | val_bpb | Artifact | -|-----|-----------|-----|---------|-------|---------|----------| -| F55 | relu | 2× | 88.7 | 6760 | 1.2284 | 14.49MB | -| **F56** | **relu²** | **2×** | **89.5** | **6700** | **1.2042** | **14.48MB** | -| F60 | leaky relu | 3× | 102.6 | 5840 | 1.2094 | 17.50MB | -| **F57** | **relu²** | **3×** | **101.5** | **5910** | **1.1878** | **17.51MB** | -| F58 | swiglu | 3× | 127.4 | 4700 | 1.1786 | 22.05MB | -| **F59** | **swiglu** | **3×** | **127.3** | **4710** | **1.1771** | **21.96MB** | - -relu² beats relu by 0.024 bpb at no cost — strictly dominant. relu² locked for budget-constrained path. - -### MLP Width Sweep (600s) - -| Run | Activation | MLP | Layers | ms/step | Steps | val_bpb | Artifact | -|-----|-----------|-----|--------|---------|-------|---------|----------| -| F56 | relu² | 2× | 12 | 89.5 | 6700 | 1.2042 | 14.48MB | -| F64 | relu² | 3× | 12 | 99.4 | 6030 | 1.1873 | 17.50MB | -| F75 | relu² | 4× | 12 | 91.6 | 6550 | 1.1795 | 20.54MB | -| F82 | relu² | 4× | 10 | 91.6 | 6550 | 1.1861 | 16.04MB | - -4× MLP at 10L beats 3× at 12L within similar budget. - -### Layer Count vs MLP Width (fp8, 600s) - -| Run | Config | Layers | ms/step | Steps | val_bpb | RT bpb | Artifact | -|-----|--------|--------|---------|-------|---------|--------|----------| -| F78 | relu² 3× fp8 | 12 | 99.3 | 6040 | 1.1884 | 1.1898 | 15.80MB | -| F77 | relu² 3× fp8 | 13 | 106.6 | 5630 | 1.2065 | 1.2077 | 16.96MB | -| F80 | relu² 2× fp8 | 15 | 106.9 | 5610 | 1.2120 | 1.2136 | 15.45MB | -| F81 | relu² 2× fp8 | 16 | 113.9 | 5270 | 1.1996 | 1.2009 | 16.33MB | -| F79 | relu² 3× fp8 | 11 | 91.5 | 6560 | 1.1920 | 1.1933 | 14.66MB | -| **F82** | **relu² 4× fp8** | **10** | **91.6** | **6550** | **1.1861** | **1.1877** | **16.04MB** | -| F83 | swiglu 3× fp8 | 10 | 105.5 | 5690 | 1.1842 | 1.1853 | 17.29MB | - -### Weight Decay at 10L 4×MLP fp8 - -| Run | WD | val_bpb | RT bpb | Artifact | -|-----|-----|---------|--------|----------| -| F82 | 0.04 | 1.1861 | 1.1877 | 16.04MB | -| F84 | 0.08 | 1.1983 | 1.1998 | 16.04MB | -| **F85** | **0.00** | **1.1828** | **1.1844** | **16.02MB** | -| S87 | 0.00 | 1.1831 | 1.1843 | 16.01MB | -| **F88** | **0.00 (EMBED=254)** | **1.1820** | **1.1839** | **16.00MB — FITS** | - -WD=0 optimal at 10L 4× — opposite to 26L result. Wider MLP needs full weight freedom. - ---- - -## Binary Quantisation Track - -### Motivation - -Binary quantisation constrains weights to {-1, +1} with no zero state. At 1 bit/param vs ternary's 1.6 bits/param, binary packs approximately 60% more parameters per MB. The hypothesis was that additional depth could compensate for the loss of the zero state. - -Starting point: the ternary best config (10L, 768d, 8h, 4kv, 4× relu², FP8, 524k batch, 599s) scoring 1.1578 sliding bpb. - -### Binary Scaling Runs - -| Run | Layers | MLP | FP | Other | Steps | ms/step | Sliding bpb | Artifact | Fits | -|-----|--------|-----|-----|-------|-------|---------|-------------|----------|------| -| F17 | 17 | 4× | FP8 | — | 4010 | 149 | 1.2022 | 17.45MB | No | -| **F1** | **14** | **4×** | **FP8** | **—** | **4820** | **124** | **1.1824** | **14.74MB** | **Yes** | -| F2 | 14 | 4× | FP8 | EMA | 4800 | 125 | 1.2110 | 14.56MB | Yes | -| S3 | 15 | 4× | FP8 | — | 1000 | 133 | 1.3114 | 15.65MB | Yes | -| S4 | 20 | 3× | FP8 | — | 1000 | 160 | 1.3077 | 16.90MB | No | -| S5 | 21 | 3× | FP4 | — | 1000 | 167 | 1.3676 | 16.64MB | No | -| S6 | 19 | 3× | FP8 | — | 1000 | 152 | 1.3130 | 16.16MB | No | -| S7 | 15 | 4× | FP8 | refiner | 1000 | 135 | 1.3123 | 15.89MB | Yes | -| S8 | 15 | 4× | FP8 | smear | 1000 | 155 | 1.3043 | 15.67MB | Yes | -| S9 | 15 | 4× | FP8 | tversky_attn | 1000 | 179 | 1.4016 | 15.74MB | Yes | - -### Key Decisions from Binary Scaling - -**MLP width (4× vs 3×):** 4× won even when 3× received 4–5 extra layers. S3 (15L 4×) outperformed S6 (19L 3×) at matched steps. Width matters more than depth past a minimum viable layer count. - -**FP storage (FP8 vs FP4):** FP4 added a 0.06 bpb roundtrip penalty and was immediately ruled out. FP8 used for all non-binary tensors. - -**Layer count:** 17L was the theoretical maximum at 4× FP8 but landed 1.45MB over budget. 15L at 15.65MB was the maximum that fit. 14L left 1.26MB headroom. - -**EMA:** Mathematically sound for binary (no zero bucket means `mean(|Q|)=1.0` always, clean roundtrip). In practice, 0.03 bpb worse — the smoothed weights apparently hurt binary's learning dynamics despite the clean quantisation math. - -**Smear:** 0.007 bpb gain at 1000 steps but added 22ms/step overhead (133→155ms). Retained for the extended binary run to test whether the gain survives the step penalty at longer training. - -**Refiner (causal conv):** Neutral at 1000 steps, added 2ms/step. Not justified. - -**Tversky attention projection:** 0.09 bpb worse. Completely incompatible with binary weights. - -**Activation:** relu² inherited from ternary sweeps, not retested for binary. SwiGLU would cost ~4MB extra across 15 layers, eliminating the layer budget advantage. - -### Extended Binary Run (Unconstrained Compute) - -To measure the binary architecture's convergence ceiling without the 10-minute wallclock constraint, a single extended run was conducted at 50,000 steps (~2 hours on 8×H100). - -**Configuration:** 15L 768d, 4× relu², FP8, smear, 524k batch tokens, seed=42, MUON_WD=0.0 - -``` -step:50000/50000 val_loss:2.9692 val_bpb:1.1497 train_time:7763s -artifact:15.60MB binary:97320960(13685760B) fp:2542200(2585072B) code:70399 -budget:15670651/16000000 (15.67/16.00MB) FITS -final_binary_roundtrip val_loss:2.9743 val_bpb:1.1516 -temp_scaling optimal_T:0.90 -final_sliding val_loss:2.9027 val_bpb:1.1239 (stride=16, T=0.90) -``` - -| Metric | Value | -|--------|-------| -| val_bpb | 1.1497 | -| RT bpb | 1.1516 | -| Sliding bpb | **1.1239** | -| Artifact | 15.60MB (15.67MB total) | -| Params | 97,320,960 | -| Steps | 50,000 | -| ms/step | 155.3 | -| Training time | ~2.15 hours | - -The 1.1239 sliding bpb demonstrates that with sufficient compute the binary architecture reaches strong quality. This validates the compression approach — nearly 100M parameters in 15.67MB via 1-bit quantisation — though the 50k steps required far exceeds the competition's 10-minute budget. - -### Binary vs Ternary at Equal Architecture (Dev Scale) - -| Metric | Binary | Ternary | Delta | -|--------|--------|---------|-------| -| val_bpb | 1.8609 | 1.8113 | Ternary wins by 0.050 | -| Artifact | 9.14MB | 11.56MB | Binary saves 2.42MB | -| ms/step | 918 | 924 | Identical | -| RT gap | 0.000 | 0.000 | Both clean | - -Ternary is better at equal architecture. Binary's only advantage is fitting more layers in the same budget. - -### Binary Conclusion - -Binary lost the depth-for-sparsity trade. The 5 extra layers (15L binary vs 10L ternary) could not overcome ternary's representational advantage from the zero state. The 0.0016 bpb gap measured at 500 dev steps significantly understated the true difference at convergence. Ternary at 1.1578 sliding bpb (10-minute budget) outperforms binary's best fitting run (F1: 1.1824 at 14L without smear) by 0.025 bpb. Even the over-budget 17L binary run (1.2022) could not match ternary. - -The extended 50k-step binary run reaching 1.1239 sliding bpb shows that binary has a competitive convergence ceiling, but it requires approximately 8× more training steps to approach competitive quality — well beyond the competition constraints. - ---- - -## Grouped MLP Investigation - -Tested GroupedTernaryLinear: splits MLP into independent groups for parameter/speed savings. - -### Real Model Results (relu² 3×, 768d, 600s) - -| Run | Config | Layers | ms/step | Steps | val_bpb | Artifact | -|-----|--------|--------|---------|-------|---------|----------| -| F64 | standard | 12 | 99.4 | 6030 | 1.1873 | 17.50MB | -| F72 | g=2 | 12 | 87.4 | 6870 | 1.2180 | 12.97MB | -| F71 | g=4 | 12 | 83.5 | 7190 | 1.2429 | 10.74MB | -| F73 | g=2 | 16 | 114.2 | 5260 | 1.2037 | 16.04MB | -| F74 | swiglu g=2 | 12 | 113.3 | 5300 | 1.2084 | 15.24MB | - -Cross-group isolation costs 0.031–0.056 bpb. Even with 4 extra layers (F73), only recovers 0.014 of the deficit. **Not viable for language modelling.** - ---- - -## Differential Attention - -Microsoft (2024): computes two attention maps from split Q/K and takes their difference. - -| Run | Config | ms/step | Steps | val_bpb | -|-----|--------|---------|-------|---------| -| F64 | standard | 99.4 | 6030 | 1.1873 | -| F68 | diff_attn | 109.3 | 5480 | 1.2094 | - -Splits 96-dim heads into 48-dim sub-heads — insufficient dimensionality for meaningful attention patterns at this model scale. - ---- - -## Sequence Refiner (CausalConvRefiner) - -| Run | Config | ms/step | Steps | val_bpb | Artifact | -|-----|--------|---------|-------|---------|----------| -| F64 | none | 99.4 | 6030 | 1.1873 | 17.50MB | -| F69 | k=3 | 102.2 | 5860 | 1.1885 | 19.92MB | -| F70 | k=5 | 103.0 | 5820 | 1.2018 | 18.13MB | - -Noise-level quality improvement with storage bloat. 12 attention layers already saturate local pattern capture. - ---- - -## ByteCNN Vocabulary Generator - -Replaces `nn.Embedding(8192, 256)` with a CNN that generates the embedding matrix from byte spellings. - -``` -step:500 loss:9.0471 — step:2000 loss:9.0471 (flat, no learning) -``` - -All 8192 CNN-generated embeddings converge to near-identical vectors at initialisation. The CNN's inductive bias (byte-similar tokens → similar embeddings) destroys the initial diversity needed for gradient signal. - ---- - -## Asymmetric Tokenizer Investigation - -8k BPE input with 256-byte output to eliminate large output projection. - -| Model | BPB | Notes | -|-------|-----|-------| -| Standard (tied, emb=256) | 3.10 | reference | -| Asymmetric parallel (emb=256) | 8.65 | byte independence assumption fails | -| Asymmetric autoregressive (emb=256) | 8.17 | tiny GRU insufficient capacity | - -Multi-byte parallel heads assume conditional independence between bytes within a token — mathematically incorrect. Sequence-length mismatch (7 BPE tokens → 70 bytes) also incompatible with the evaluation framework. - ---- - -## Linear Alternative Exploration - -Systematic notebook testing of linear layer alternatives at real model dimensions (768d). - -### Projection Benchmark (DIM → DIM, H100) - -| Model | Params | ms | vs Linear | -|-------|--------|-----|-----------| -| Linear | 589,824 | 0.07ms | 1.00× | -| LowRank r=64 | 98,304 | 0.03ms | 0.44× | -| BlockDiag b=4 | 147,456 | 0.03ms | 0.40× | -| Grouped g=4 | 147,456 | 0.03ms | 0.40× | -| BD4 + mix32 | 196,608 | 0.07ms | 0.97× | -| Hash 65536 | 65,536 | 0.08ms | 1.13× | - -BlockDiag/Grouped offer speed advantages but cross-group isolation degrades LM quality in practice. - ---- - -## H100 Microbenchmark Results - -Standalone kernel timing vs torch.compile behaviour (critical lesson: standalone microbenchmarks can mislead when torch.compile fuses operations). - -### STE Speed - -| Variant | ms/call | -|---------|---------| -| Current | 0.041 | -| Reciprocal | 0.043 | - -No gain — 48 STE calls/step = ~2ms overhead (unavoidable). - -### Contiguous Checks - -Q and K are contiguous after RoPE. V is non-contiguous (view into fused QKV). V's `.contiguous()` costs 0.065ms/call = 0.78ms/step (necessary for flash_attn). - -### RoPE Variants - -Current (half-split + cat) is fastest at 0.52ms/call. - -### Softcap: Poly5 vs Tanh - -| Variant | ms/call | -|---------|---------| -| Poly5 (current) | 8.43 | -| Poly3 | 5.98 | -| Tanh | 2.12 | -| Hardtanh | 0.71 | - -**Critical finding:** Tanh is 4× faster standalone due to H100 hardware transcendental units. However in the real training loop, torch.compile fuses poly5 with surrounding ops into a single kernel. **Switching to tanh broke fusion — F63 was 16ms/step slower.** Poly5 retained. - -### CE + Z-Loss Fusion - -| Variant | ms/call (fwd+bwd) | -|---------|-------------------| -| Separate (current) | 16.56 | -| Fused (shared LSE) | 12.33 | - -**Same lesson:** 4.2ms saving standalone, but torch.compile already optimises `F.cross_entropy`. Manual gather+logsumexp prevents optimisation. Current approach retained. - ---- - -## Efficiency Analysis - -### BPB Gained Per Component - -| Component | BPB gain | Source | -|-----------|----------|--------| -| relu → relu² | −0.024 | F55 vs F56 | -| MLP 2× → 3× (relu²) | −0.017 | F56 vs F64 | -| MLP 3× → 4× (relu²) | −0.008 | F64 vs F75 | -| relu² → swiglu (at 3×) | −0.010 | F64 vs F59 | -| +1 layer (average) | −0.0012 | scaling data | -| fp16 → fp8 (RT penalty) | +0.002 | run 42 vs 49 | -| Sliding eval stride=16 | −0.025 | F22 data | -| WD=0.04 vs WD=0 (at 26L) | −0.001 | F7 vs F6 | - -### MB Cost Per Component - -| Component | MB/layer | -|-----------|----------| -| relu² 2× layer | 0.767 | -| relu² 3× layer | 1.003 | -| relu² 4× layer | 1.220 | -| swiglu 3× layer | 1.357 | -| fp16 → fp8 (fixed saving) | −2.51 | - -### Efficiency Ratio (BPB Gained Per MB Spent) - -| Change | BPB gain | MB cost | BPB/MB | -|--------|----------|---------|--------| -| relu → relu² | −0.024 | 0.00 | infinite (free) | -| Sliding eval | −0.025 | 0.00 | infinite (free) | -| MLP 2× → 3× | −0.017 | +2.83 (12L) | −0.0060/MB | -| MLP 3× → 4× | −0.008 | +2.83 (12L) | −0.0028/MB | -| relu² → swiglu | −0.010 | +4.25 (12L) | −0.0024/MB | -| +1 layer (relu² 2×) | −0.0012 | +0.767 | −0.0016/MB | -| +1 layer (relu² 3×) | −0.0012 | +1.003 | −0.0012/MB | - -MLP 2×→3× is the most efficient paid upgrade. relu² and sliding eval are free wins. - -### Layer Budget at 768d - -| Config | Max Layers | Est ms/step | -|--------|-----------|-------------| -| relu² 2× fp16 | 14L | ~95ms | -| relu² 2× fp8 | 17L | ~97ms | -| relu² 3× fp16 | 10L | ~99ms | -| relu² 3× fp8 | 13L | ~106ms | -| relu² 4× fp8 | 10L | ~92ms | -| swiglu 3× fp8 | 9L | ~105ms | - ---- - -## Ternary-Incompatible Techniques - -These are not merely unhelpful but structurally incompatible with 1.58-bit quantisation: - -| Technique | Mechanism of failure | -|-----------|---------------------| -| **EMA** | Weight averaging → values cluster near zero → ternary rounds most to 0 → 0.12 bpb RT gap | -| **TTT-LoRA** | LoRA delta computed outside RMSNorm space that TernaryLinear normalises into. Corrupts calibrated representations at convergence | -| **Ternary prototypes + sigmoid** | Sigmoid membership needs continuous values. Ternary {-1,0,+1} collapses membership patterns → 0.077 RT gap | -| **LM head rank factorisation** | SVD factors U,V need fp16 precision. Storage exceeds original tied embedding | - ---- - -## Software Optimisations - -| Optimisation | Saving | Notes | -|---|---|---| -| Fused QKV (c_q+c_k+c_v → single matmul) | ~2ms/step | Safe: in_features divisible by all group sizes | -| Fused SwiGLU/relu² (gate+up → single wide matmul) | ~2-4ms/step | Same params, fewer kernel launches | -| Z-loss regularisation (1e-4 x logsumexp²) | quality | Anchors logits, keeps STE gradients sharp | -| DataLoader int16 transfer (pin then cast on GPU) | ~1ms/step | 4× less PCIe bandwidth | -| FlashAttention-3 | ~13ms/step | ~9% speedup, ~380 free training steps | -| TernaryLinear bf16 weights, cleaner STE | ~1ms/step | Eliminates fp32 roundtrip | -| DDP static_graph + gradient_as_bucket_view | ~1ms/step | Free when find_unused=False | -| Fused optimizer loop (LR set + step in one pass) | ~0.5ms/step | Fewer Python-level iterations | -| Removed CUBLAS determinism tax | ~1ms/step | Not required for competition | -| Temperature grid: 5 points instead of 21 | ~1s total | T=0.90 consistently with relu² | -| Temp scaling moved to eval phase | ~3 steps gained | No longer steals training time | -| `_e()` helper for Hyperparameters | -1.8KB code | Eliminates env var boilerplate | -| 3D tensor ternary quantisation | storage fix | Conv1d weights reshaped to 2D for ternary | - ---- - -## Rejected Techniques (Summary) - -| Technique | Reason | -|-----------|--------| -| Tversky (all variants) | Quality-neutral on FineWeb LM — confirmed via synthetic data analysis; speed penalty with relu² | -| Differential attention | Halved head_dim (96→48) degrades quality at this model scale | -| Grouped MLP (g=2, g=4) | Cross-group isolation costs 0.031–0.056 bpb; not recoverable with extra layers | -| CausalConvRefiner | Noise-level quality; storage bloat from Conv1d weights | -| ByteCNN vocabulary generator | Embedding collapse — CNN inductive bias destroys initial diversity | -| Asymmetric tokenizer | Byte independence assumption incorrect; sequence mismatch with eval framework | -| EMA | Incompatible with ternary — weight averaging causes 0.12 bpb RT gap | -| TTT-LoRA | Architectural incompatibility with RMSNorm space in TernaryLinear | -| LM head factorisation | SVD factors bloat artifact beyond budget; unrecoverable quality loss | -| MTP | 0.006 bpb worse — model capacity too limited for auxiliary objectives | -| BigramHash | 0.020 bpb worse at convergence; fp16 table displaces ternary layers | -| Seq/batch schedule | Recompile and step penalties dominate at 600s wallclock | -| SmearModule | +22% step cost for −0.001 gain within ternary 10-minute budget | -| Depth recurrence | Halves effective steps; OOM at DR=3 | -| AdamW for matrix params | Clearly inferior to Muon for ternary weights | -| FP4 storage | 0.026–0.029 RT gap even with QAT — unrecoverable | -| Tanh softcap | Faster standalone but breaks torch.compile kernel fusion | -| Fused CE+Z-loss | Same — breaks compile optimisation | -| 16 heads at 768d | 48-dim head_dim insufficient for meaningful attention | -| relu (plain) | Strictly dominated by relu² | -| leaky relu | Strictly dominated by relu² | -| Distillation (in-run) | Train-from-scratch teacher always worse than supervised | -| reduce-overhead compile | Rotary + embed_proj_rev incompatible with CUDA graphs | -| max-autotune compile | 30+ minute kernel search prohibitive for 600s runs | -| Skip weights zero-init | 0.010 bpb worse — decoder needs skip signal from step 0 | -| EMBED_DIM=0 (full 512) | 19.78MB artifact — 3.78MB over budget | -| Untie lm_head full-rank | 7.3MB budget overrun not justified by 0.005 bpb gain | - ---- - -## Decision Log - -| Decision | Rationale | -|----------|-----------| -| 8k vocabulary | −0.42 bpb, largest single win | -| relu² activation | −0.024 bpb vs relu, free (no cost) | -| 4×MLP width | Best BPB within budget at 10L; 0.008 better than 3× | -| 10L 768d | Minimum viable depth at 768d with maximum MLP width | -| WD=0.0 at 10L 4× | Opposite to deep models — wider MLP needs full weight freedom | -| fp8 storage | Halves fp_params (5MB→2.5MB), enables wider MLP within budget | -| EMBED_DIM=254 | 256-2 dims to fit artifact+code under 16,000,000 byte budget; ~0.0004 bpb cost | -| BITNET_GROUP_SIZE=128 | Same quality as 64; saves 0.69MB | -| 8 heads, 4 KV, 96-dim head_dim | 16h at 48-dim insufficient; MHA only +0.0012 at +1.5MB | -| Poly softcap | Fuses with torch.compile; tanh breaks fusion | -| ROPE_BASE=5000 + YaRN 2048 | Best frequency calibration | -| Muon optimizer | Newton-Schulz normalisation compensates for ternary STE gradient attenuation | -| MUON_BACKEND_STEPS=3 | Equivalent to 5 at convergence; +190 extra steps | -| MUON_MOMENTUM=0.95 | Both directions degrade; affects artifact via zero_frac | -| WARMDOWN=20% | Asymmetric — too little hurts more than too much | -| MATRIX_LR=0.04 | Higher LR compensates for ternary STE gradient attenuation | -| SCALAR_LR=0.02 | Optimal — scalars do not pass through STE | -| TIED_EMBED_LR=0.02 | Optimal | -| TRAIN_BATCH_TOKENS=524k | Optimal tradeoff between gradient quality and step count | -| Base-3 + LZMA | 39% reduction over int8+zlib | -| Shrinkage fix | Eliminates all RT gaps universally | -| Skip weights ones-init | Decoder needs skip signal from step 0; zeros costs 0.010 bpb | -| Tied embeddings | Untie costs 7.3MB; not justified | -| Standard attn projection | Tversky quality-neutral; grouped destroys quality | -| No EMA | Fundamentally incompatible with ternary | -| No TTT | RMSNorm space incompatibility confirmed across 6 runs | -| No MTP | Confirmed post-fix: 0.006 bpb worse | -| Temperature scaling T=0.90 | relu² logits slightly underconfident; auto-calibrated | -| Fused QKV + relu² | ~130-180 free training steps per run | -| Z-loss regularisation | Anchors logits; keeps STE gradients sharp | -| FlashAttention-3 | Free ~380 extra training steps per 600s run | -| Sliding eval stride=16 | Best quality when eval budget unconstrained | -| Optimizer coverage fix | embed_proj/embed_proj_rev now train; +0.055 bpb improvement | -| MAX_WALLCLOCK_SECONDS=599 | 1s leeway for safety margin | -| Binary 15L 768d 4× fp8 | 97M params in 15.67MB — maximum parameter density; convergence ceiling validated at 50k steps | diff --git a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/binary_log.txt b/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/binary_log.txt deleted file mode 100644 index f75377dcdf..0000000000 --- a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/binary_log.txt +++ /dev/null @@ -1,1518 +0,0 @@ -"""Binary training script for OpenAI's Parameter Golf Challenge. Ciprian-Florin Ifrim - 24 March 2026""" - -import copy -import glob -import io -import math -import os -import random -import sys -import time -import lzma -from pathlib import Path -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.parallel import DistributedDataParallel as DDP -from flash_attn_interface import flash_attn_func - -# --------------------------------------------------------------------------- -# Hyperparameters (all configurable via environment variables) -# --------------------------------------------------------------------------- -def _e(k, d, t=str): - v = os.environ.get(k, str(d)) - if t == bool: return bool(int(v)) - return t(v) - -class Hyperparameters: - data_path = _e("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = _e("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id = os.environ.get("RUN_ID", f"run_{int(time.time())}") - seed = _e("SEED", 1337, int) - compile_mode = _e("COMPILE_MODE", "default") - val_batch_size = _e("VAL_BATCH_SIZE", 524288, int) - val_loss_every = _e("VAL_LOSS_EVERY", 500, int) - train_log_every = _e("TRAIN_LOG_EVERY", 10, int) - iterations = _e("ITERATIONS", 2000, int) - warmdown_fraction = _e("WARMDOWN_FRACTION", 0.2, float) - warmup_steps = _e("WARMUP_STEPS", 20, int) - train_batch_tokens = _e("TRAIN_BATCH_TOKENS", 524288, int) - train_seq_len = _e("TRAIN_SEQ_LEN", 1024, int) - max_wallclock_seconds = _e("MAX_WALLCLOCK_SECONDS", 0.0, float) - vocab_size = _e("VOCAB_SIZE", 1024, int) - num_layers = _e("NUM_LAYERS", 16, int) - num_kv_heads = _e("NUM_KV_HEADS", 4, int) - model_dim = _e("MODEL_DIM", 512, int) - num_heads = _e("NUM_HEADS", 8, int) - mlp_mult = _e("MLP_MULT", 2, int) - tie_embeddings = _e("TIE_EMBEDDINGS", 1, int) - rope_base = _e("ROPE_BASE", 10000.0, float) - rope_type = _e("ROPE_TYPE", "rope") - yarn_max_len = _e("YARN_MAX_LEN", 4096, int) - logit_softcap = _e("LOGIT_SOFTCAP", 30.0, float) - softcap_type = _e("SOFTCAP_TYPE", "poly") - tied_embed_init_std = _e("TIED_EMBED_INIT_STD", 0.005, float) - qk_gain_init = _e("QK_GAIN_INIT", 1.5, float) - activation_type = _e("ACTIVATION", "swiglu") - embed_dim = _e("EMBED_DIM", 0, int) - bigram_hash = _e("BIGRAM_HASH", 0, bool) - mtp_heads_count = _e("MTP_HEADS", 0, int) - training_depth_recurrence = _e("TRAINING_DEPTH_RECURRENCE", 1, int) - eval_depth_recurrence = _e("EVAL_DEPTH_RECURRENCE", 1, int) - attn_proj_type = _e("ATTN_PROJ_TYPE", "standard") - logit_head_type = _e("LOGIT_HEAD_TYPE", "standard") - tversky_num_features = _e("TVERSKY_NUM_FEATURES", 16, int) - tversky_feature_pools = _e("TVERSKY_FEATURE_POOLS", 0, int) - tversky_membership = _e("TVERSKY_MEMBERSHIP", "sigmoid") - diff_attn = _e("DIFF_ATTN", 0, bool) - refiner = _e("REFINER", 0, bool) - refiner_kernel = _e("REFINER_KERNEL", 3, int) - mlp_groups = _e("MLP_GROUPS", 0, int) - embed_lr = _e("EMBED_LR", 0.6, float) - head_lr = _e("HEAD_LR", 0.008, float) - adam_lr = _e("ADAM_LR", 1e-3, float) - adam_wd = _e("ADAM_WD", 0.05, float) - untie_at_fraction = _e("UNTIE_AT_FRACTION", 0.0, float) - tied_embed_lr = _e("TIED_EMBED_LR", 0.05, float) - corr_weight_lr = _e("CORR_WEIGHT_LR", 0.05, float) - smear = _e("SMEAR", 0, bool) - seq_len_start = _e("SEQ_LEN_START", 0, int) - seq_schedule_fraction = _e("SEQ_SCHEDULE_FRACTION", 0.33, float) - batch_tokens_start = _e("BATCH_TOKENS_START", 0, int) - batch_schedule_fraction = _e("BATCH_SCHEDULE_FRACTION", 0.33, float) - churn_log_every = _e("CHURN_LOG_EVERY", 500, int) - matrix_lr = _e("MATRIX_LR", 0.04, float) - scalar_lr = _e("SCALAR_LR", 0.04, float) - muon_momentum = _e("MUON_MOMENTUM", 0.95, float) - muon_backend_steps = _e("MUON_BACKEND_STEPS", 5, int) - muon_wd = _e("MUON_WD", 0.0, float) - matrix_optimizer = _e("MATRIX_OPTIMIZER", "muon") - muon_momentum_warmup_start = _e("MUON_MOMENTUM_WARMUP_START", 0.85, float) - muon_momentum_warmup_steps = _e("MUON_MOMENTUM_WARMUP_STEPS", 500, int) - beta1 = _e("BETA1", 0.9, float) - beta2 = _e("BETA2", 0.95, float) - adam_eps = _e("ADAM_EPS", 1e-8, float) - grad_clip_norm = _e("GRAD_CLIP_NORM", 0.0, float) - bitnet_group_size = _e("BITNET_GROUP_SIZE", 64, int) - sliding_eval = _e("SLIDING_EVAL", 0, bool) - sliding_eval_stride = _e("SLIDING_EVAL_STRIDE", 64, int) - sliding_batch_size = _e("SLIDING_BATCH_SIZE", 64, int) - temp_scaling = _e("TEMP_SCALING", 0, bool) - _fp_raw = os.environ.get("FP_STORAGE", "0") - fp_storage = True if _fp_raw == "FP8" else ("fp4" if _fp_raw == "FP4" else False) - ema = _e("EMA", 0, bool) - ema_decay = _e("EMA_DECAY", 0.995, float) - ema_start_fraction = _e("EMA_START_FRACTION", 0.5, float) - -CTP = ("attn_scale","attn_scales","mlp_scale","mlp_scales","resid_mix","resid_mixes","q_gain","diff_lambda","skip_weight","skip_weights","vocab_bias","refiner.gate") - -# --------------------------------------------------------------------------- -# Binary packing — bitpacking (8 weights/byte = 1 bit/param, lossless) -# --------------------------------------------------------------------------- -def pack_binary(q: Tensor) -> tuple[bytes, int]: - bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) - n = len(bits) - pad = (8 - n % 8) % 8 - if pad: - bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) - groups = bits.reshape(-1, 8) - packed = np.zeros(len(groups), dtype=np.uint8) - for i in range(8): - packed |= groups[:, i] << i - return packed.tobytes(), n - -def unpack_binary(data: bytes, n: int) -> Tensor: - packed = np.frombuffer(data, dtype=np.uint8) - bits = np.zeros((len(packed), 8), dtype=np.int8) - for i in range(8): - bits[:, i] = (packed >> i) & 1 - flat = bits.reshape(-1)[:n] - return torch.from_numpy(flat.astype(np.int8) * 2 - 1) - -# --------------------------------------------------------------------------- -# FP4 quantization (per-row absmax, 2 values packed per byte) -# --------------------------------------------------------------------------- -def quantize_to_int4(t: Tensor) -> tuple[Tensor, Tensor, list]: - t32 = t.float() - orig_shape = t32.shape - if t32.ndim < 2: - t32 = t32.unsqueeze(0) - absmax = t32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(t32 / scale), -7, 7).to(torch.int8) - flat = q.reshape(-1) - if flat.numel() % 2 != 0: - flat = F.pad(flat, (0, 1)) - low = (flat[0::2] + 8).to(torch.uint8) - high = (flat[1::2] + 8).to(torch.uint8) - return low | (high << 4), scale.half().squeeze(-1), list(orig_shape) - -def dequantize_from_int4(packed: Tensor, scale: Tensor, shape: list) -> Tensor: - low = (packed & 0x0F).to(torch.int8) - 8 - high = ((packed >> 4) & 0x0F).to(torch.int8) - 8 - flat = torch.zeros(packed.numel() * 2, dtype=torch.int8) - flat[0::2] = low - flat[1::2] = high - numel = 1 - for s in shape: - numel *= s - flat = flat[:numel].float() - if len(shape) <= 1: - return (flat * scale.float().squeeze()).reshape(shape) - return (flat.reshape(-1, shape[-1]) * scale.float().unsqueeze(-1)).reshape(shape) - -# --------------------------------------------------------------------------- -# State dict serialization (binary + fp16/fp8/fp4) -# --------------------------------------------------------------------------- -def q_sd(state_dict: dict, group_size: int = 64, fp_storage=False, binary_override_names: set | None = None) -> tuple[dict, dict]: - "Binary for large 2D weight matrices, fp16/fp8/fp4 for everything else." - quantized = {} - stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} - for name, tensor in state_dict.items(): - if "mtp_heads" in name: - continue - t = tensor.detach().cpu().float().contiguous() - t_orig_shape = list(t.shape) - if t.ndim == 3: - t = t.reshape(t.shape[0], -1) - is_binary_candidate = ( - t.ndim == 2 and t.numel() > 65_536 - and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name and "bigram_emb" not in name and "lm_head_correction" not in name and "lm_head_U" not in name and "lm_head_V" not in name - and "prototypes" not in name and "tversky" not in name - ) or (binary_override_names is not None and name in binary_override_names) - if is_binary_candidate: - pad = (group_size - t.shape[1] % group_size) % group_size - t_padded = F.pad(t, (0, pad)) if pad > 0 else t - t_grouped = t_padded.reshape(-1, group_size) - scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = torch.where(t_grouped >= 0, - torch.ones_like(t_grouped, dtype=torch.int8), - -torch.ones_like(t_grouped, dtype=torch.int8)) - packed_bytes, n_bits = pack_binary(q) - quantized[name] = { - "type": "binary", "packed": packed_bytes, - "scale": scale.half().squeeze(-1), - "shape": list(t.shape), "padded_cols": t_padded.shape[1], - "group_size": group_size, "n_bits": n_bits, - "orig_shape": t_orig_shape, - } - stats["binary_params"] += t.numel() - stats["binary_bytes"] += len(packed_bytes) + scale.numel() * 2 - elif fp_storage == "fp4" and t.ndim == 2: - packed, scale, orig_shape = quantize_to_int4(t) - quantized[name] = {"type": "fp4", "packed": packed, "scale": scale, "shape": orig_shape} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += packed.numel() + scale.numel() * 2 - elif fp_storage and t.ndim == 2: - quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() - else: - quantized[name] = {"type": "fp16", "data": t.half()} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() * 2 - return quantized, stats - -def deq_sd(quantized: dict, target_dtype=torch.bfloat16): - "Reconstruct full-precision state dict from quantized representation." - out = {} - for name, entry in quantized.items(): - if entry["type"] == "binary": - q = unpack_binary(entry["packed"], entry["n_bits"]) - q = q.float().reshape(-1, entry["group_size"]) - scale = entry["scale"].float().unsqueeze(-1) - # No shrinkage correction needed: binary has no zeros, q.abs().mean() == 1.0 always - t = (q * scale).reshape(-1, entry["padded_cols"]) - shape = entry["shape"] - result = t[:shape[0], :shape[1]].to(target_dtype) - orig = entry.get("orig_shape") - out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() - elif entry["type"] == "fp8": - out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() - elif entry["type"] == "fp4": - out[name] = dequantize_from_int4(entry["packed"], entry["scale"], entry["shape"]).to(target_dtype).contiguous() - else: - out[name] = entry["data"].to(target_dtype).contiguous() - return out - -# --------------------------------------------------------------------------- -# Binary diagnostics (logged during training) -# --------------------------------------------------------------------------- -_prev_committed: dict = {} -def churn_fn(model: nn.Module, group_size: int = 64): - global _prev_committed - total = flipped = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() - if name in _prev_committed: - flipped += int(np.sum(q != _prev_committed[name])) - total += q.size - _prev_committed[name] = q - return flipped / max(total, 1) - -# --------------------------------------------------------------------------- -# Muon optimizer (Newton-Schulz orthogonalized momentum) -# --------------------------------------------------------------------------- -def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, wd: float = 0.0): - super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr, momentum = group["lr"], group["momentum"] - backend_steps, nesterov = group["backend_steps"], group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() - g = ns_orth(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr:curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - wd = group.get("wd", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.mul_(1 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - -# --------------------------------------------------------------------------- -# Data loading -# --------------------------------------------------------------------------- -def ld_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" Tensor: - chunks = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos:self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank, self.world_size, self.device = rank, world_size, device - self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x, y -# --------------------------------------------------------------------------- -# Model -# --------------------------------------------------------------------------- -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - -def apply_qat_ste(w: Tensor, fp_storage: str | bool) -> Tensor: - """Applies Straight-Through Estimator (STE) for FP4 or FP8 simulated quantization.""" - if not fp_storage: - return w - if fp_storage == "fp4": - absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(w / scale), -7.0, 7.0) - w_sim = q * scale - return (w_sim - w).detach() + w - elif fp_storage is True or fp_storage == "fp8": - w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) - return (w_sim - w).detach() + w - return w - -class QATLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = False, fp_storage: str | bool = False): - super().__init__(in_features, out_features, bias=bias) - self.fp_storage = fp_storage - def forward(self, x: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.linear(x, w_qat.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) - -class QATEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, fp_storage: str | bool = False): - super().__init__(num_embeddings, embedding_dim) - self.fp_storage = fp_storage - def forward(self, input: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.embedding(input, w_qat, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) - -class BinaryLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=False, group_size=64): - super().__init__(in_features, out_features, bias=bias) - self.group_size = group_size - def forward(self, x: Tensor) -> Tensor: - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) - w_binary = w + ((q * scale).reshape(w.shape) - w).detach() - return F.linear(x, w_binary, - self.bias.to(x.dtype) if self.bias is not None else None) - -class NormedBinaryLinear(BinaryLinear): - "Binary linear with RMSNorm on input — for output projections receiving un-normalized activations." - def forward(self, x: Tensor) -> Tensor: - return super().forward(F.rms_norm(x, (x.size(-1),))) - -class GroupedBinaryLinear(nn.Module): - "Grouped linear with binary STE. Weight stored as 2D [groups*group_out, group_in] for binary quantization compatibility." - def __init__(self, in_features, out_features, groups=4, group_size=64, normed=False): - super().__init__() - assert in_features % groups == 0 and out_features % groups == 0 - self.groups = groups - self.group_in = in_features // groups - self.group_out = out_features // groups - self.group_size = group_size - self.normed = normed - self.weight = nn.Parameter(torch.randn(groups * self.group_out, self.group_in) * 0.02) - def forward(self, x: Tensor) -> Tensor: - if self.normed: - x = F.rms_norm(x, (x.size(-1),)) - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) - w_binary = w + ((q * scale).reshape(w.shape) - w).detach() - w_grouped = w_binary.reshape(self.groups, self.group_out, self.group_in) - bsz = x.shape[:-1] - x_g = x.reshape(*bsz, self.groups, self.group_in) - out = torch.einsum('...gi,goi->...go', x_g, w_grouped) - return out.reshape(*bsz, self.groups * self.group_out) - -class TverskyProjection(nn.Module): - "Tversky similarity: S = θ·f(A∩B) - α·f(A\\B) - β·f(B\\A). Three modes." - def __init__(self, in_features: int, out_features: int, num_features: int = 16, - group_size: int = 64, use_shared_features: bool = False, - membership: str = "sigmoid"): - super().__init__() - self.group_size = group_size - self.num_features = num_features - self.membership_type = membership - self.no_features_mode = (num_features == 0) - if not self.no_features_mode and not use_shared_features: - self.features = nn.Parameter(torch.empty(num_features, in_features).uniform_(-0.02, 0.02)) - else: - self.register_parameter('features', None) - self.prototypes = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.02, 0.02)) - self.theta = nn.Parameter(torch.tensor(1.0)) - self.alpha = nn.Parameter(torch.tensor(0.5)) - self.beta = nn.Parameter(torch.tensor(0.5)) - - def _binary_ste(self, w: Tensor) -> Tensor: - w_bf16 = w.bfloat16() - g = self.group_size - w_grouped = w_bf16.reshape(-1, g) - scale = w_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = torch.where(w_grouped >= 0, torch.ones_like(w_grouped), -torch.ones_like(w_grouped)) - w_binary = w_bf16 + ((q * scale).reshape(w_bf16.shape) - w_bf16).detach() - return w_binary.reshape(w.shape) - - def _membership(self, t: Tensor) -> Tensor: - if self.membership_type == "poly": - return torch.clamp(t * 5.0 / 4.0 + 0.5, 0.0, 1.0) - elif self.membership_type == "tanh": - return (torch.tanh(t * 5.0) + 1.0) * 0.5 - else: - return torch.sigmoid(t * 5.0) - - def forward(self, x: Tensor, shared_features: Tensor | None = None) -> Tensor: - proto = self._binary_ste(self.prototypes) - if self.no_features_mode: - x_f = x @ proto.t() - p_norm = F.normalize(proto, dim=-1) - p_f = p_norm @ p_norm.t() - else: - feat = (shared_features if shared_features is not None else self.features).float() - x_f = x @ feat.t() - p_f = proto @ feat.t() - x_s = self._membership(x_f) - p_s = self._membership(p_f) - x_a = x_f * x_s - p_a = p_f * p_s - t, a, b = self.theta.abs(), self.alpha.abs(), self.beta.abs() - return t * (x_a @ p_a.t()) - a * (x_a @ (1 - p_s).t()) - b * ((1 - x_s) @ p_a.t()) - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: - param.data = param.data.float() - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, no_cache: bool = False, - rope_type: str = "rope", yarn_max_len: int = 4096, train_seq_len: int = 1024): - super().__init__() - self.no_cache = no_cache - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - if rope_type == "yarn": - scale = train_seq_len / yarn_max_len - freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) - ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) - inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - def forward(self, seq_len, device, dtype): - if self.no_cache: - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - return freqs.cos()[None, :, None, :].to(dtype=dtype), freqs.sin()[None, :, None, :].to(dtype=dtype) - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size=64, attn_proj_type="standard", tversky_num_features=16, - tversky_feature_pools=0, no_cache=False, rope_type="rope", - yarn_max_len=4096, train_seq_len=1024, tversky_membership="sigmoid", - diff_attn=False): - super().__init__() - self.num_heads, self.num_kv_heads = num_heads, num_kv_heads - self.head_dim = dim // num_heads - self.diff_attn = diff_attn - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.c_qkv = BinaryLinear(dim, self.q_size + 2 * self.kv_size, bias=False, group_size=group_size) - self.proj = NormedBinaryLinear(dim, dim, bias=False, group_size=group_size) if attn_proj_type != "tversky" else None - if self.proj is not None: - self.proj._zero_init = True - self.tversky_proj = TverskyProjection( - dim, dim, num_features=tversky_num_features, group_size=group_size, - use_shared_features=(tversky_feature_pools > 0), - membership=tversky_membership, - ) if attn_proj_type == "tversky" else None - self.shared_features = None - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - if diff_attn: - self.diff_lambda = nn.Parameter(torch.full((num_heads,), 0.5, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, no_cache=no_cache, - rope_type=rope_type, yarn_max_len=yarn_max_len, - train_seq_len=train_seq_len) - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qkv_out = self.c_qkv(x) - q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - if self.diff_attn: - half = self.head_dim // 2 - q1, q2 = q[..., :half], q[..., half:] - k1, k2 = k[..., :half], k[..., half:] - v1, v2 = v[..., :half], v[..., half:] - y1 = flash_attn_func(q1.contiguous(), k1.contiguous(), v1.contiguous(), causal=True) - y2 = flash_attn_func(q2.contiguous(), k2.contiguous(), v2.contiguous(), causal=True) - lam = self.diff_lambda.to(dtype=y1.dtype)[None, None, :, None] - y = torch.cat([y1 - lam * y2, y1 + lam * y2], dim=-1) - else: - y = flash_attn_func( - q.contiguous(), - k.contiguous(), - v.contiguous(), - causal=True - ) - y = y.reshape(bsz, seqlen, dim) - return self.tversky_proj(y, self.shared_features) if self.tversky_proj is not None else self.proj(y) - -class MLP(nn.Module): - def __init__(self, dim, mlp_mult, group_size=64, activation="swiglu", mlp_groups=0): - super().__init__() - hidden = mlp_mult * dim - self.activation = activation - if mlp_groups > 0: - if activation == "swiglu": - self.gate_up = GroupedBinaryLinear(dim, hidden * 2, groups=mlp_groups, group_size=group_size) - else: - self.fc = GroupedBinaryLinear(dim, hidden, groups=mlp_groups, group_size=group_size) - self.proj = GroupedBinaryLinear(hidden, dim, groups=mlp_groups, group_size=group_size, normed=True) - else: - if activation == "swiglu": - self.gate_up = BinaryLinear(dim, hidden * 2, bias=False, group_size=group_size) - else: - self.fc = BinaryLinear(dim, hidden, bias=False, group_size=group_size) - self.proj = NormedBinaryLinear(hidden, dim, bias=False, group_size=group_size) - self.proj._zero_init = True - def forward(self, x: Tensor) -> Tensor: - if self.activation == "swiglu": - gu = self.gate_up(x) - gate, up = gu.chunk(2, dim=-1) - return self.proj(F.silu(gate) * up) - elif self.activation == "relu": - return self.proj(torch.relu(self.fc(x))) - elif self.activation == "leaky_relu": - return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.01)) - else: # relu2 - return self.proj(torch.relu(self.fc(x)).square()) - -class SmearModule(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - cumsum = x.cumsum(dim=1) - counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) - smeared = cumsum / counts - gate = torch.tanh(self.gate.to(dtype=x.dtype)) - return x + gate * (smeared - x) - -class CausalConvRefiner(nn.Module): - "Causal Conv1d that refines hidden states using local n-gram context." - def __init__(self, dim: int, kernel_size: int = 3): - super().__init__() - self.kernel_size = kernel_size - self.conv = nn.Conv1d(dim, dim, kernel_size, padding=0, bias=False) - self.gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - h = x.permute(0, 2, 1) - h = F.pad(h, (self.kernel_size - 1, 0)) - h = self.conv(h) - h = h.permute(0, 2, 1) - return x + torch.tanh(self.gate.to(dtype=x.dtype)) * F.rms_norm(h, (h.size(-1),)) - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - rope_base: float, qk_gain_init: float, group_size: int=64, - activation: str="swiglu", attn_proj_type: str="standard", - tversky_num_features: int=16, tversky_feature_pools: int=0, no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn: bool=False, mlp_groups: int=0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size, attn_proj_type, tversky_num_features, - tversky_feature_pools, no_cache, rope_type, yarn_max_len, - train_seq_len, tversky_membership, diff_attn) - self.mlp = MLP(dim, mlp_mult, group_size, activation, mlp_groups) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.smear = SmearModule(dim) if smear else None - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0] * x + mix[1] * x0 - n = self.attn_norm(x) - x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) - x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) - if self.smear is not None: - x = self.smear(x) - return x - -class GPT(nn.Module): - def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, - tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, - group_size: int = 64, activation: str = "swiglu", mtp_heads_count: int = 0, - embed_dim: int = 0, attn_proj_type: str = "standard", logit_head_type: str = "standard", - tversky_num_features: int = 16, tversky_feature_pools: int = 0, - training_depth_recurrence: int=1, fp_storage=False, bigram_hash: bool=False, - softcap_type: str="poly", no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn=False, mlp_groups=0, refiner=False, refiner_kernel=3): - super().__init__() - self.training_depth_recurrence = training_depth_recurrence - self.fp_storage = fp_storage - self.tie_embeddings = tie_embeddings - self.logit_softcap = logit_softcap - self.softcap_type = softcap_type - self.embed_dim = embed_dim if embed_dim > 0 else model_dim - self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) - self.bigram_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) if bigram_hash else None - if self.bigram_emb is not None: - nn.init.zeros_(self.bigram_emb.weight) - self.lm_head_correction = nn.Parameter( - torch.zeros(vocab_size, self.embed_dim)) if tie_embeddings == 2 else None - self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, fp_storage=fp_storage) if self.embed_dim != model_dim else None - self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, fp_storage=fp_storage) if ( - self.embed_dim != model_dim and logit_head_type != "tversky") else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - # Shared Tversky feature pools (if enabled and num_features > 0) - if attn_proj_type == "tversky" and tversky_feature_pools > 0 and tversky_num_features > 0: - self.tversky_feature_pools_list = nn.ParameterList([ - nn.Parameter(torch.empty(tversky_num_features, model_dim).uniform_(-0.02, 0.02)) - for _ in range(tversky_feature_pools) - ]) - else: - self.tversky_feature_pools_list = None - self.blocks = nn.ModuleList([ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - group_size, activation, attn_proj_type, tversky_num_features, tversky_feature_pools, - no_cache, smear, rope_type, yarn_max_len, train_seq_len, tversky_membership, - diff_attn, mlp_groups) - for _ in range(num_layers) - ]) - # Inject shared feature pool references into attention layers - if self.tversky_feature_pools_list is not None: - for i, block in enumerate(self.blocks): - pool_idx = (i * tversky_feature_pools) // num_layers - block.attn.shared_features = self.tversky_feature_pools_list[pool_idx] - self.final_norm = RMSNorm() - self.refiner = CausalConvRefiner(model_dim, kernel_size=refiner_kernel) if refiner else None - self.mtp_heads = nn.ModuleList([ - nn.Linear(model_dim, vocab_size, bias=False) for _ in range(mtp_heads_count) - ]) - for h in self.mtp_heads: - nn.init.zeros_(h.weight) - self.logit_head_type = logit_head_type - if logit_head_type == "tversky" and tversky_num_features == 0 and vocab_size > 1024: - raise ValueError( - f"Tversky logit head with no-features mode creates O(V^2) = {vocab_size}x{vocab_size} " - f"matrix per forward pass. Use tversky_num_features > 0 or a smaller vocab." - ) - self.tversky_head = TverskyProjection( - model_dim, vocab_size, num_features=tversky_num_features, - membership=tversky_membership, - ) if logit_head_type == "tversky" else None - self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) - self.lm_head._zero_init = True - if self.lm_head is not None and (tie_embeddings or logit_head_type == "tversky"): - self.lm_head.weight.requires_grad_(False) - self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) - self._init_weights(tied_embed_init_std) - def _init_weights(self, tied_embed_init_std: float) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) - for module in self.modules(): - if isinstance(module, BinaryLinear) and not getattr(module, "_zero_init", False): - nn.init.normal_(module.weight, mean=0.0, std=0.02) - elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - def _compute_logits(self, x: Tensor) -> Tensor: - if self.tversky_head is not None: - logits_raw = self.tversky_head(x) - elif self.tie_embeddings: - if self.embed_proj_rev is not None: - proj = self.embed_proj_rev(x) - else: - proj = x - weight = self.tok_emb.weight - if self.lm_head_correction is not None: - weight = weight + self.lm_head_correction - logits_raw = F.linear(proj, weight.to(x.dtype)) - else: - logits_raw = self.lm_head(x) - return logits_raw + self.vocab_bias.to(x.dtype) - def _softcap(self, logits: Tensor) -> Tensor: - s = self.logit_softcap - if self.softcap_type == "tanh": - return s * torch.tanh(logits / s) - x_sc = torch.clamp(logits / s, -2.0, 2.0) - x2 = x_sc * x_sc - return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) - def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", temperature: float = 1.0) -> Tensor: - x = self.tok_emb(input_ids).float() - if self.bigram_emb is not None: - prev = F.pad(input_ids[:, :-1], (1, 0), value=0) - x = x + self.bigram_emb(prev).float() - if self.embed_proj is not None: - x = self.embed_proj(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - # U-Net style encoder/decoder with skip connections - skips = [] - for i in range(self.num_encoder_layers): - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[bi](x, x0) - x_normed = self.final_norm(x) - if self.refiner is not None: - x_normed = self.refiner(x_normed) - # Standard training/eval path - x_flat = x_normed.reshape(-1, x_normed.size(-1)) - targets = target_ids.reshape(-1) - logits = self._softcap(self._compute_logits(x_flat)) - if reduction == "none": - return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) - # Fused CE + Z-loss: single logsumexp computation - logits_f = logits.float() - lse = torch.logsumexp(logits_f, dim=-1) - target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) - main_loss = (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() - # Multi-token prediction auxiliary loss (training only) - if self.training and len(self.mtp_heads) > 0: - mtp_loss = torch.zeros((), device=main_loss.device) - for k, head in enumerate(self.mtp_heads): - shift = k + 2 - if target_ids.shape[1] > shift: - mtp_tgt = target_ids[:, shift:].reshape(-1) - mtp_in = x_normed[:, :target_ids.shape[1] - shift, :].reshape(-1, x_normed.shape[-1]) - mtp_loss = mtp_loss + F.cross_entropy(head(mtp_in).float(), mtp_tgt, reduction="mean") - main_loss = main_loss + 0.1 * mtp_loss / len(self.mtp_heads) - return main_loss - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- -def build_luts(sp, vocab_size: int, device: torch.device): - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): - files = sorted(glob.glob(pattern)) - assert files, f"No files: {pattern}" - tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() - if max_tok > 0: tok = tok[:max_tok + 1] - u = ((tok.numel() - 1) // seq_len) * seq_len - return tok[:u + 1] - -def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature: float = 1.0): - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_start in range(seq_start, seq_end, local_batch_seqs): - batch_end = min(batch_start + local_batch_seqs, seq_end) - raw_start = batch_start * args.train_seq_len - raw_end = batch_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) - x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch_loss = model(x, y, temperature=temperature).detach() - n = float(y.numel()) - loss_sum += batch_loss.to(torch.float64) * n - token_count += n - prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) - tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) - tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride: int = 64, temperature: float = 1.0): - seq_len = args.train_seq_len - batch_size = args.sliding_batch_size - total_tokens = val_tokens.numel() - 1 - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - all_starts = list(range(0, total_tokens - seq_len, stride)) - my_starts = all_starts[rank::world_size] - model.eval() - with torch.inference_mode(): - for i in range(0, len(my_starts), batch_size): - batch_starts = my_starts[i:i + batch_size] - starts_t = torch.tensor(batch_starts, dtype=torch.int64) - offsets = torch.arange(seq_len + 1, dtype=torch.int64) - indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) - local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) - x = local_batch[:, :-1] - y = local_batch[:, 1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() - for b, start in enumerate(batch_starts): - score_from = 0 if start == 0 else seq_len - stride - scored = per_token_loss[b, score_from:] - sx, sy = x[b, score_from:], y[b, score_from:] - loss_sum += scored.to(torch.float64).sum() - token_count += scored.numel() - tok_bytes = base_bytes_lut[sy].to(torch.int16) - tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -# --------------------------------------------------------------------------- -# Temperature scaling -# --------------------------------------------------------------------------- -def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut): - best_t, best_loss = 1.0, float("inf") - for t in [0.90, 0.95, 1.00, 1.05, 1.10]: - loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, temperature=t) - if loss < best_loss: - best_loss = loss - best_t = t - return best_t - -# --------------------------------------------------------------------------- -# Training -# --------------------------------------------------------------------------- -def main() -> None: - args = Hyperparameters() - code = Path(__file__).read_text(encoding="utf-8") - if args.matrix_optimizer != "adamw": - global ns_orth - ns_orth = torch.compile(ns_orth) - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - grad_accum_steps = max(1, 8 // world_size) - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - os.makedirs("logs/cuda/", exist_ok=True) - logfile = f"logs/cuda/{args.run_id}.txt" if master_process else None - if master_process: - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Python {sys.version}", console=False) - log0(f"PyTorch {torch.__version__}", console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - val_tokens = ld_val(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts( - sp, args.vocab_size, device) - - # --- Model --- - base_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - group_size=args.bitnet_group_size, activation=args.activation_type, mtp_heads_count=args.mtp_heads_count, - embed_dim=args.embed_dim, attn_proj_type=args.attn_proj_type, logit_head_type=args.logit_head_type, - tversky_num_features=args.tversky_num_features, tversky_feature_pools=args.tversky_feature_pools, - training_depth_recurrence=args.training_depth_recurrence, fp_storage=args.fp_storage, - bigram_hash=args.bigram_hash, softcap_type=args.softcap_type, no_cache=(args.compile_mode == "reduce-overhead"), - smear=args.smear, rope_type=args.rope_type, yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, - tversky_membership=args.tversky_membership, diff_attn=args.diff_attn, - refiner=args.refiner, refiner_kernel=args.refiner_kernel, mlp_groups=args.mlp_groups, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, nn.Linear): - module.float() - restore_low_dim_params_to_fp32(base_model) - if base_model.lm_head is not None and (args.tie_embeddings or args.logit_head_type == "tversky"): - base_model.lm_head.weight.requires_grad_(False) - torch._dynamo.config.optimize_ddp = False - compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else None) - use_find_unused = args.untie_at_fraction > 0 or args.mtp_heads_count > 0 or not args.tie_embeddings - model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, - find_unused_parameters=use_find_unused, - static_graph=not use_find_unused, - gradient_as_bucket_view=True) if distributed else compiled_model - - # --- Optimizers --- - _excl = {"tok_emb.weight", "lm_head.weight", "lm_head_correction"} - all_other_params = [(n, p) for n, p in base_model.named_parameters() - if not any(eh in n for eh in _excl)] - matrix_params = [p for n, p in all_other_params - if p.ndim == 2 and not any(pat in n for pat in CTP)] - scalar_params = [p for n, p in all_other_params - if p.ndim < 2 or any(pat in n for pat in CTP)] - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - opt_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - if args.matrix_optimizer == "adamw": - opt_muon = torch.optim.AdamW( - [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) - else: - opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, wd=args.muon_wd) - for g in opt_muon.param_groups: - g["base_lr"] = args.matrix_lr - opt_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - opt_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - optimizers = [opt for opt in [opt_tok, opt_muon, opt_scalar, opt_head] if opt is not None] - if base_model.lm_head_correction is not None: - opt_corr = torch.optim.Adam( - [{"params": [base_model.lm_head_correction], - "lr": args.corr_weight_lr, "base_lr": args.corr_weight_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - optimizers.append(opt_corr) - - # --- Log all hyperparameters --- - log0("--- Hyperparameters ---", console=False) - log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) if not a.startswith("_") and a not in ("train_files","val_files") and not callable(getattr(args,a))), console=False) - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"params:{n_params} L:{args.num_layers} d:{args.model_dim} h:{args.num_heads} kv:{args.num_kv_heads} ws:{world_size} ga:{grad_accum_steps} s:{args.seed}") - # --- Data loader & helpers --- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all(): - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float): - if args.warmdown_fraction <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) - return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 - warmdown_ms = max_wallclock_ms * args.warmdown_fraction - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - _seq_switched = False - _batch_switched = False - active_seq_len = args.seq_len_start if args.seq_len_start > 0 else args.train_seq_len - active_batch_tokens = args.batch_tokens_start if args.batch_tokens_start > 0 else args.train_batch_tokens - # --- Compiler warmup --- - if args.warmup_steps > 0: - _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} - _os = [copy.deepcopy(o.state_dict()) for o in optimizers] - model.train() - for ws in range(args.warmup_steps): - zero_grad_all() - for mi in range(grad_accum_steps): - if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) - (loss * grad_scale).backward() - for o in optimizers: o.step() - zero_grad_all() - log0(f"warmup:{ws+1}/{args.warmup_steps}") - base_model.load_state_dict(_ms, strict=True) - for o, s in zip(optimizers, _os): o.load_state_dict(s) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # --- EMA model --- - ema_model = None - _ema_started = False - _ema_steps = 0 - if args.ema: - ema_model = copy.deepcopy(base_model) - for p in ema_model.parameters(): - p.requires_grad_(False) - - # --- Main training loop --- - training_time_ms = 0.0 - stop_after_step: int | None = None - _untied = False - train_loss = torch.zeros((), device=device) - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms") - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - # Sequence length schedule - if args.seq_len_start > 0 and not _seq_switched: - if max_wallclock_ms is not None: - should_switch_seq = elapsed_ms >= args.seq_schedule_fraction * max_wallclock_ms - else: - should_switch_seq = step >= int(args.iterations * args.seq_schedule_fraction) - if should_switch_seq: - active_seq_len = args.train_seq_len - _seq_switched = True - torch._dynamo.reset() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - log0(f"step:{step} seq_len_switch:{args.seq_len_start}->{active_seq_len}") - - # Batch size schedule - if args.batch_tokens_start > 0 and not _batch_switched: - if max_wallclock_ms is not None: - should_switch_batch = elapsed_ms >= args.batch_schedule_fraction * max_wallclock_ms - else: - should_switch_batch = step >= int(args.iterations * args.batch_schedule_fraction) - if should_switch_batch: - active_batch_tokens = args.train_batch_tokens - _batch_switched = True - log0(f"step:{step} batch_switch:{args.batch_tokens_start}->{active_batch_tokens}") - zero_grad_all() - train_loss.zero_() - for micro in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = model(x, y) - train_loss.add_(loss.detach()) - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - # Untie lm_head at configured fraction of training - if args.untie_at_fraction > 0: - if max_wallclock_ms is not None: - should_untie = not _untied and elapsed_ms >= args.untie_at_fraction * max_wallclock_ms - else: - should_untie = not _untied and step >= int(args.iterations * args.untie_at_fraction) - if should_untie and base_model.tie_embeddings: - with torch.no_grad(): - base_weight = base_model.tok_emb.weight.float() - if base_model.lm_head_correction is not None: - base_weight = base_weight + base_model.lm_head_correction.float() - if base_model.embed_proj_rev is not None: - full_weight = base_weight @ base_model.embed_proj_rev.weight.float() - else: - full_weight = base_weight - base_model.lm_head.weight.copy_(full_weight) - base_model.tie_embeddings = False - base_model.lm_head.weight.requires_grad_(True) - for g in opt_head.param_groups: - g["lr"] = g["base_lr"] = args.head_lr - _untied = True - torch._dynamo.reset() - log0(f"step:{step} untied lm_head (head_lr={args.head_lr})") - - # Muon momentum warmup - if args.matrix_optimizer != "adam": - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - for g in opt_muon.param_groups: - g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - - # LR scheduling - for opt in optimizers: - for g in opt.param_groups: - g["lr"] = g["base_lr"] * scale - opt.step() - zero_grad_all() - # EMA update - if ema_model is not None: - if not _ema_started: - if max_wallclock_ms is not None: - should_start_ema = elapsed_ms >= args.ema_start_fraction * max_wallclock_ms - else: - should_start_ema = step >= int(args.iterations * args.ema_start_fraction) - if should_start_ema: - _ema_started = True - _ema_steps = 0 - with torch.no_grad(): - for ep, bp in zip(ema_model.parameters(), base_model.parameters()): - ep.data.copy_(bp.data) - log0(f"step:{step} ema_started") - if _ema_started: - _ema_steps += 1 - decay = min(args.ema_decay, (1.0 + _ema_steps) / (10.0 + _ema_steps)) - with torch.no_grad(): - for ep, bp in zip(ema_model.parameters(), base_model.parameters()): - ep.data.mul_(decay).add_(bp.data, alpha=1.0 - decay) - step += 1 - approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if args.train_log_every > 0 and step % args.train_log_every == 0: - log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} t:{approx_ms:.0f}ms avg:{approx_ms/step:.1f}ms") - if args.churn_log_every > 0 and step % args.churn_log_every == 0: - log0(f"step:{step} churn:{churn_fn(base_model, args.bitnet_group_size):.4f}") - # Wallclock cap sync - if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: - reached_cap = approx_ms >= max_wallclock_ms - if distributed: - cap_t = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) - reached_cap = bool(cap_t.item()) - if reached_cap: - stop_after_step = step - - # --- Serialization --- - if master_process: - sd = (ema_model if ema_model is not None and _ema_started else base_model).state_dict() - if base_model.tie_embeddings or args.logit_head_type == "tversky": - sd.pop("lm_head.weight", None) - - # Compute binary overrides for no-features Tversky prototypes - binary_overrides = set() - for n, m in base_model.named_modules(): - if isinstance(m, TverskyProjection) and m.no_features_mode: - binary_overrides.add(n + ".prototypes") - binary_overrides = binary_overrides or None - q_obj, q_stats = q_sd(sd, group_size=args.bitnet_group_size, fp_storage=args.fp_storage, binary_override_names=binary_overrides) - buf = io.BytesIO() - torch.save(q_obj, buf) - final_blob = lzma.compress(buf.getvalue(), preset=9) - with open("final_model.binary.ptz", "wb") as f: - f.write(final_blob) - artifact_bytes = len(final_blob) - code_bytes = len(code.encode("utf-8")) - total = artifact_bytes + code_bytes - log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") - log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) {'FITS' if total <= 16000000 else 'OVER'}") - if args.eval_depth_recurrence > 0: - base_model.training_depth_recurrence = args.eval_depth_recurrence - log0(f"eval_depth_recurrence:{args.eval_depth_recurrence}") - - # --- All ranks load roundtrip weights and evaluate --- - if distributed: - dist.barrier() - with open("final_model.binary.ptz", "rb") as f: - loaded = torch.load(io.BytesIO(lzma.decompress(f.read())), map_location="cpu", weights_only=False) - base_model.load_state_dict(deq_sd(loaded), strict=False) - if ema_model is not None: - ema_model.load_state_dict(deq_sd(loaded), strict=False) - torch._dynamo.reset() - q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - log0(f"final_binary_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") - - opt_temp = 1.0 - if args.temp_scaling: - torch.cuda.synchronize() - t_temp = time.perf_counter() - calibration_tokens = train_loader.stream.take(65536).to(device) - opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut) - torch.cuda.synchronize() - temp_time_ms = 1000.0 * (time.perf_counter() - t_temp) - log0(f"temp_scaling optimal_T:{opt_temp:.2f} eval_time:{temp_time_ms:.0f}ms") - - if args.sliding_eval: - torch.cuda.synchronize() - t_sliding = time.perf_counter() - sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, stride=args.sliding_eval_stride, - temperature=opt_temp) - torch.cuda.synchronize() - sliding_time_ms = 1000.0 * (time.perf_counter() - t_sliding) - log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " - f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) eval_time:{sliding_time_ms:.0f}ms") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Python 3.13.12 | packaged by Anaconda, Inc. | (main, Feb 24 2026, 16:13:31) [GCC 14.3.0] -PyTorch 2.10.0+cu128 ---- Hyperparameters --- -activation_type=relu2 adam_eps=1e-08 adam_lr=0.05 adam_wd=0.05 attn_proj_type=standard batch_schedule_fraction=0.33 batch_tokens_start=0 beta1=0.9 beta2=0.95 bigram_hash=False bitnet_group_size=128 churn_log_every=1000 compile_mode=default corr_weight_lr=0.02 data_path=./data/datasets/fineweb10B_sp8192 diff_attn=False ema=False ema_decay=0.995 ema_start_fraction=0.5 embed_dim=254 embed_lr=0.6 eval_depth_recurrence=0 fp_storage=True grad_clip_norm=0.0 head_lr=0.02 iterations=50000 logit_head_type=standard logit_softcap=10.0 matrix_lr=0.04 matrix_optimizer=muon max_wallclock_seconds=0.0 mlp_groups=0 mlp_mult=4 model_dim=768 mtp_heads_count=0 muon_backend_steps=3 muon_momentum=0.95 muon_momentum_warmup_start=0.85 muon_momentum_warmup_steps=500 muon_wd=0.0 num_heads=8 num_kv_heads=4 num_layers=15 qk_gain_init=2.25 refiner=False refiner_kernel=3 rope_base=5000.0 rope_type=yarn run_id=pushing_run_binary_1 scalar_lr=0.02 seed=42 seq_len_start=0 seq_schedule_fraction=0.0 sliding_batch_size=256 sliding_eval=True sliding_eval_stride=16 smear=True softcap_type=poly temp_scaling=True tie_embeddings=1 tied_embed_init_std=0.005 tied_embed_lr=0.02 tokenizer_path=./data/tokenizers/fineweb_8192_bpe.model train_batch_tokens=524288 train_log_every=500 train_seq_len=1024 training_depth_recurrence=0 tversky_feature_pools=0 tversky_membership=sigmoid tversky_num_features=0 untie_at_fraction=0.0 val_batch_size=524288 val_loss_every=0 vocab_size=8192 warmdown_fraction=0.2 warmup_steps=5 yarn_max_len=2048 -params:106154616 L:15 d:768 h:8 kv:4 ws:8 ga:1 s:42 -warmup:1/5 -warmup:2/5 -warmup:3/5 -warmup:4/5 -warmup:5/5 -step:500/50000 loss:3.6805 t:77540ms avg:155.1ms -step:1000/50000 loss:3.3485 t:155075ms avg:155.1ms -step:1000 churn:0.0000 -step:1500/50000 loss:3.3714 t:232880ms avg:155.3ms -step:2000/50000 loss:3.3187 t:310516ms avg:155.3ms -step:2000 churn:0.1984 -step:2500/50000 loss:3.2573 t:388417ms avg:155.4ms -step:3000/50000 loss:3.1844 t:465980ms avg:155.3ms -step:3000 churn:0.1457 -step:3500/50000 loss:3.3885 t:543772ms avg:155.4ms -step:4000/50000 loss:3.3496 t:621381ms avg:155.3ms -step:4000 churn:0.1252 -step:4500/50000 loss:3.3527 t:699211ms avg:155.4ms -step:5000/50000 loss:3.2171 t:776797ms avg:155.4ms -step:5000 churn:0.1151 -step:5500/50000 loss:3.0536 t:854512ms avg:155.4ms -step:6000/50000 loss:3.1355 t:932007ms avg:155.3ms -step:6000 churn:0.1087 -step:6500/50000 loss:3.1928 t:1009731ms avg:155.3ms -step:7000/50000 loss:3.2378 t:1087253ms avg:155.3ms -step:7000 churn:0.1041 -step:7500/50000 loss:3.1585 t:1164994ms avg:155.3ms -step:8000/50000 loss:3.1436 t:1242513ms avg:155.3ms -step:8000 churn:0.1009 -step:8500/50000 loss:3.0573 t:1320248ms avg:155.3ms -step:9000/50000 loss:3.0523 t:1397837ms avg:155.3ms -step:9000 churn:0.0982 -step:9500/50000 loss:3.3082 t:1475596ms avg:155.3ms -step:10000/50000 loss:3.3521 t:1553112ms avg:155.3ms -step:10000 churn:0.0964 -step:10500/50000 loss:3.1877 t:1630835ms avg:155.3ms -step:11000/50000 loss:2.7388 t:1708388ms avg:155.3ms -step:11000 churn:0.0948 -step:11500/50000 loss:3.2052 t:1786100ms avg:155.3ms -step:12000/50000 loss:3.2859 t:1863613ms avg:155.3ms -step:12000 churn:0.0935 -step:12500/50000 loss:3.0326 t:1941282ms avg:155.3ms -step:13000/50000 loss:3.2551 t:2018764ms avg:155.3ms -step:13000 churn:0.0924 -step:13500/50000 loss:3.1339 t:2096463ms avg:155.3ms -step:14000/50000 loss:3.0606 t:2173965ms avg:155.3ms -step:14000 churn:0.0915 -step:14500/50000 loss:3.1752 t:2251634ms avg:155.3ms -step:15000/50000 loss:3.0206 t:2329140ms avg:155.3ms -step:15000 churn:0.0907 -step:15500/50000 loss:3.2017 t:2406858ms avg:155.3ms -step:16000/50000 loss:3.1705 t:2484387ms avg:155.3ms -step:16000 churn:0.0900 -step:16500/50000 loss:3.0774 t:2562139ms avg:155.3ms -step:17000/50000 loss:3.2494 t:2639671ms avg:155.3ms -step:17000 churn:0.0894 -step:17500/50000 loss:3.2024 t:2717393ms avg:155.3ms -step:18000/50000 loss:3.1627 t:2794977ms avg:155.3ms -step:18000 churn:0.0888 -step:18500/50000 loss:3.1733 t:2872744ms avg:155.3ms -step:19000/50000 loss:3.2055 t:2950389ms avg:155.3ms -step:19000 churn:0.0885 -step:19500/50000 loss:3.2026 t:3028137ms avg:155.3ms -step:20000/50000 loss:2.9144 t:3105704ms avg:155.3ms -step:20000 churn:0.0880 -step:20500/50000 loss:3.2154 t:3183466ms avg:155.3ms -step:21000/50000 loss:3.1016 t:3261044ms avg:155.3ms -step:21000 churn:0.0878 -step:21500/50000 loss:3.2065 t:3338791ms avg:155.3ms -step:22000/50000 loss:3.1611 t:3416326ms avg:155.3ms -step:22000 churn:0.0875 -step:22500/50000 loss:3.2578 t:3494047ms avg:155.3ms -step:23000/50000 loss:3.0689 t:3571604ms avg:155.3ms -step:23000 churn:0.0871 -step:23500/50000 loss:3.2047 t:3649319ms avg:155.3ms -step:24000/50000 loss:3.0689 t:3726856ms avg:155.3ms -step:24000 churn:0.0868 -step:24500/50000 loss:3.2355 t:3804562ms avg:155.3ms -step:25000/50000 loss:3.2085 t:3882065ms avg:155.3ms -step:25000 churn:0.0865 -step:25500/50000 loss:3.2235 t:3959778ms avg:155.3ms -step:26000/50000 loss:3.2484 t:4037303ms avg:155.3ms -step:26000 churn:0.0863 -step:26500/50000 loss:3.2419 t:4114994ms avg:155.3ms -step:27000/50000 loss:3.1215 t:4192502ms avg:155.3ms -step:27000 churn:0.0861 -step:27500/50000 loss:3.1305 t:4270187ms avg:155.3ms -step:28000/50000 loss:3.2679 t:4347697ms avg:155.3ms -step:28000 churn:0.0858 -step:28500/50000 loss:3.1768 t:4425383ms avg:155.3ms -step:29000/50000 loss:3.1519 t:4502876ms avg:155.3ms -step:29000 churn:0.0857 -step:29500/50000 loss:3.1614 t:4580510ms avg:155.3ms -step:30000/50000 loss:3.2341 t:4658001ms avg:155.3ms -step:30000 churn:0.0855 -step:30500/50000 loss:3.1673 t:4735648ms avg:155.3ms -step:31000/50000 loss:3.0884 t:4813158ms avg:155.3ms -step:31000 churn:0.0854 -step:31500/50000 loss:3.0147 t:4890803ms avg:155.3ms -step:32000/50000 loss:3.1793 t:4968281ms avg:155.3ms -step:32000 churn:0.0853 -step:32500/50000 loss:3.1626 t:5045990ms avg:155.3ms -step:33000/50000 loss:3.3086 t:5123506ms avg:155.3ms -step:33000 churn:0.0851 -step:33500/50000 loss:2.9607 t:5201190ms avg:155.3ms -step:34000/50000 loss:3.1584 t:5278703ms avg:155.3ms -step:34000 churn:0.0850 -step:34500/50000 loss:3.2311 t:5356349ms avg:155.3ms -step:35000/50000 loss:3.0574 t:5433881ms avg:155.3ms -step:35000 churn:0.0848 -step:35500/50000 loss:3.1880 t:5511613ms avg:155.3ms -step:36000/50000 loss:3.0474 t:5589157ms avg:155.3ms -step:36000 churn:0.0848 -step:36500/50000 loss:3.1925 t:5666894ms avg:155.3ms -step:37000/50000 loss:3.0935 t:5744417ms avg:155.3ms -step:37000 churn:0.0847 -step:37500/50000 loss:3.1454 t:5822114ms avg:155.3ms -step:38000/50000 loss:2.9914 t:5899675ms avg:155.3ms -step:38000 churn:0.0846 -step:38500/50000 loss:3.1192 t:5977449ms avg:155.3ms -step:39000/50000 loss:3.1994 t:6055002ms avg:155.3ms -step:39000 churn:0.0845 -step:39500/50000 loss:3.1586 t:6132704ms avg:155.3ms -step:40000/50000 loss:3.1402 t:6210265ms avg:155.3ms -step:40000 churn:0.0845 -step:40500/50000 loss:3.2176 t:6287989ms avg:155.3ms -step:41000/50000 loss:3.1743 t:6365543ms avg:155.3ms -step:41000 churn:0.0831 -step:41500/50000 loss:3.1811 t:6443269ms avg:155.3ms -step:42000/50000 loss:3.0934 t:6520796ms avg:155.3ms -step:42000 churn:0.0810 -step:42500/50000 loss:3.0804 t:6598538ms avg:155.3ms -step:43000/50000 loss:3.1341 t:6676105ms avg:155.3ms -step:43000 churn:0.0788 -step:43500/50000 loss:3.0942 t:6753855ms avg:155.3ms -step:44000/50000 loss:3.0144 t:6831414ms avg:155.3ms -step:44000 churn:0.0769 -step:44500/50000 loss:2.8582 t:6909098ms avg:155.3ms -step:45000/50000 loss:3.3925 t:6986654ms avg:155.3ms -step:45000 churn:0.0745 -step:45500/50000 loss:3.0488 t:7064379ms avg:155.3ms -step:46000/50000 loss:2.9942 t:7141950ms avg:155.3ms -step:46000 churn:0.0721 -step:46500/50000 loss:3.0737 t:7219653ms avg:155.3ms -step:47000/50000 loss:3.1052 t:7297260ms avg:155.3ms -step:47000 churn:0.0688 -step:47500/50000 loss:3.1031 t:7375013ms avg:155.3ms -step:48000/50000 loss:3.0978 t:7452604ms avg:155.3ms -step:48000 churn:0.0648 -step:48500/50000 loss:3.0704 t:7530338ms avg:155.3ms -step:49000/50000 loss:3.0631 t:7607877ms avg:155.3ms -step:49000 churn:0.0586 -step:49500/50000 loss:2.9547 t:7685573ms avg:155.3ms -step:50000/50000 loss:3.0994 t:7763153ms avg:155.3ms -step:50000 churn:0.0453 -step:50000/50000 val_loss:2.9692 val_bpb:1.1497 train_time:7763355ms -artifact:15.60MB binary:97320960(13685760B) fp:2542200(2585072B) code:70399 -budget:15670651/16000000 (15.67/16.00MB) FITS -final_binary_roundtrip val_loss:2.9743 val_bpb:1.1516 -temp_scaling optimal_T:0.90 eval_time:245ms -final_sliding val_loss:2.9027 val_bpb:1.1239 (stride=16, T=0.90) eval_time:768782ms diff --git a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/fineweb_8192_bpe.model b/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/fineweb_8192_bpe.model deleted file mode 100644 index 6574784f5f..0000000000 Binary files a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/fineweb_8192_bpe.model and /dev/null differ diff --git a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/fineweb_8192_bpe.vocab b/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/fineweb_8192_bpe.vocab deleted file mode 100644 index 6e194bf03c..0000000000 --- a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/fineweb_8192_bpe.vocab +++ /dev/null @@ -1,8192 +0,0 @@ - 0 - 0 - 0 - 0 -<0x00> 0 -<0x01> 0 -<0x02> 0 -<0x03> 0 -<0x04> 0 -<0x05> 0 -<0x06> 0 -<0x07> 0 -<0x08> 0 -<0x09> 0 -<0x0A> 0 -<0x0B> 0 -<0x0C> 0 -<0x0D> 0 -<0x0E> 0 -<0x0F> 0 -<0x10> 0 -<0x11> 0 -<0x12> 0 -<0x13> 0 -<0x14> 0 -<0x15> 0 -<0x16> 0 -<0x17> 0 -<0x18> 0 -<0x19> 0 -<0x1A> 0 -<0x1B> 0 -<0x1C> 0 -<0x1D> 0 -<0x1E> 0 -<0x1F> 0 -<0x20> 0 -<0x21> 0 -<0x22> 0 -<0x23> 0 -<0x24> 0 -<0x25> 0 -<0x26> 0 -<0x27> 0 -<0x28> 0 -<0x29> 0 -<0x2A> 0 -<0x2B> 0 -<0x2C> 0 -<0x2D> 0 -<0x2E> 0 -<0x2F> 0 -<0x30> 0 -<0x31> 0 -<0x32> 0 -<0x33> 0 -<0x34> 0 -<0x35> 0 -<0x36> 0 -<0x37> 0 -<0x38> 0 -<0x39> 0 -<0x3A> 0 -<0x3B> 0 -<0x3C> 0 -<0x3D> 0 -<0x3E> 0 -<0x3F> 0 -<0x40> 0 -<0x41> 0 -<0x42> 0 -<0x43> 0 -<0x44> 0 -<0x45> 0 -<0x46> 0 -<0x47> 0 -<0x48> 0 -<0x49> 0 -<0x4A> 0 -<0x4B> 0 -<0x4C> 0 -<0x4D> 0 -<0x4E> 0 -<0x4F> 0 -<0x50> 0 -<0x51> 0 -<0x52> 0 -<0x53> 0 -<0x54> 0 -<0x55> 0 -<0x56> 0 -<0x57> 0 -<0x58> 0 -<0x59> 0 -<0x5A> 0 -<0x5B> 0 -<0x5C> 0 -<0x5D> 0 -<0x5E> 0 -<0x5F> 0 -<0x60> 0 -<0x61> 0 -<0x62> 0 -<0x63> 0 -<0x64> 0 -<0x65> 0 -<0x66> 0 -<0x67> 0 -<0x68> 0 -<0x69> 0 -<0x6A> 0 -<0x6B> 0 -<0x6C> 0 -<0x6D> 0 -<0x6E> 0 -<0x6F> 0 -<0x70> 0 -<0x71> 0 -<0x72> 0 -<0x73> 0 -<0x74> 0 -<0x75> 0 -<0x76> 0 -<0x77> 0 -<0x78> 0 -<0x79> 0 -<0x7A> 0 -<0x7B> 0 -<0x7C> 0 -<0x7D> 0 -<0x7E> 0 -<0x7F> 0 -<0x80> 0 -<0x81> 0 -<0x82> 0 -<0x83> 0 -<0x84> 0 -<0x85> 0 -<0x86> 0 -<0x87> 0 -<0x88> 0 -<0x89> 0 -<0x8A> 0 -<0x8B> 0 -<0x8C> 0 -<0x8D> 0 -<0x8E> 0 -<0x8F> 0 -<0x90> 0 -<0x91> 0 -<0x92> 0 -<0x93> 0 -<0x94> 0 -<0x95> 0 -<0x96> 0 -<0x97> 0 -<0x98> 0 -<0x99> 0 -<0x9A> 0 -<0x9B> 0 -<0x9C> 0 -<0x9D> 0 -<0x9E> 0 -<0x9F> 0 -<0xA0> 0 -<0xA1> 0 -<0xA2> 0 -<0xA3> 0 -<0xA4> 0 -<0xA5> 0 -<0xA6> 0 -<0xA7> 0 -<0xA8> 0 -<0xA9> 0 -<0xAA> 0 -<0xAB> 0 -<0xAC> 0 -<0xAD> 0 -<0xAE> 0 -<0xAF> 0 -<0xB0> 0 -<0xB1> 0 -<0xB2> 0 -<0xB3> 0 -<0xB4> 0 -<0xB5> 0 -<0xB6> 0 -<0xB7> 0 -<0xB8> 0 -<0xB9> 0 -<0xBA> 0 -<0xBB> 0 -<0xBC> 0 -<0xBD> 0 -<0xBE> 0 -<0xBF> 0 -<0xC0> 0 -<0xC1> 0 -<0xC2> 0 -<0xC3> 0 -<0xC4> 0 -<0xC5> 0 -<0xC6> 0 -<0xC7> 0 -<0xC8> 0 -<0xC9> 0 -<0xCA> 0 -<0xCB> 0 -<0xCC> 0 -<0xCD> 0 -<0xCE> 0 -<0xCF> 0 -<0xD0> 0 -<0xD1> 0 -<0xD2> 0 -<0xD3> 0 -<0xD4> 0 -<0xD5> 0 -<0xD6> 0 -<0xD7> 0 -<0xD8> 0 -<0xD9> 0 -<0xDA> 0 -<0xDB> 0 -<0xDC> 0 -<0xDD> 0 -<0xDE> 0 -<0xDF> 0 -<0xE0> 0 -<0xE1> 0 -<0xE2> 0 -<0xE3> 0 -<0xE4> 0 -<0xE5> 0 -<0xE6> 0 -<0xE7> 0 -<0xE8> 0 -<0xE9> 0 -<0xEA> 0 -<0xEB> 0 -<0xEC> 0 -<0xED> 0 -<0xEE> 0 -<0xEF> 0 -<0xF0> 0 -<0xF1> 0 -<0xF2> 0 -<0xF3> 0 -<0xF4> 0 -<0xF5> 0 -<0xF6> 0 -<0xF7> 0 -<0xF8> 0 -<0xF9> 0 -<0xFA> 0 -<0xFB> 0 -<0xFC> 0 -<0xFD> 0 -<0xFE> 0 -<0xFF> 0 -▁t -0 -▁a -1 -in -2 -he -3 -re -4 -on -5 -er -6 -▁the -7 -▁s -8 -▁w -9 -or -10 -at -11 -nd -12 -ou -13 -▁c -14 -it -15 -es -16 -▁f -17 -is -18 -ing -19 -en -20 -▁b -21 -▁p -22 -▁o -23 -an -24 -ed -25 -▁to -26 -al -27 -▁m -28 -ar -29 -▁and -30 -▁in -31 -▁of -32 -▁d -33 -le -34 -ic -35 -as -36 -▁h -37 -om -38 -ion -39 -▁th -40 -il -41 -▁T -42 -▁l -43 -ent -44 -ve -45 -▁I -46 -ro -47 -st -48 -▁y -49 -▁e -50 -▁re -51 -▁n -52 -▁S -53 -▁g -54 -et -55 -ct -56 -▁A -57 -▁C -58 -▁you -59 -ly -60 -ay -61 -id -62 -▁for -63 -▁on -64 -▁is -65 -ot -66 -▁be -67 -ow -68 -ol -69 -am -70 -ac -71 -ig -72 -us -73 -ad -74 -el -75 -▁M -76 -im -77 -ver -78 -ith -79 -ut -80 -▁st -81 -▁P -82 -ation -83 -▁with -84 -ur -85 -▁B -86 -▁that -87 -ir -88 -▁W -89 -ch -90 -▁he -91 -▁it -92 -▁The -93 -ce -94 -ill -95 -ers -96 -un -97 -▁al -98 -▁D -99 -ul -100 -▁an -101 -▁H -102 -▁F -103 -out -104 -ra -105 -ke -106 -▁pro -107 -▁wh -108 -▁as -109 -▁are -110 -se -111 -ter -112 -▁we -113 -▁ha -114 -▁R -115 -oo -116 -if -117 -ge -118 -our -119 -pp -120 -▁at -121 -ate -122 -ess -123 -▁com -124 -▁or -125 -▁con -126 -▁L -127 -her -128 -ore -129 -est -130 -▁fr -131 -ment -132 -igh -133 -▁- -134 -ab -135 -▁N -136 -▁se -137 -▁ne -138 -ld -139 -ort -140 -▁G -141 -▁E -142 -ri -143 -ist -144 -▁( -145 -▁your -146 -op -147 -▁O -148 -▁ex -149 -em -150 -ure -151 -ity -152 -▁r -153 -ant -154 -qu -155 -▁v -156 -▁was -157 -art -158 -ust -159 -▁have -160 -ive -161 -um -162 -▁this -163 -▁from -164 -pe -165 -▁de -166 -oc -167 -▁sh -168 -th -169 -ain -170 -up -171 -ies -172 -▁will -173 -▁by -174 -ight -175 -▁ch -176 -and -177 -os -178 -▁can -179 -ie -180 -nt -181 -all -182 -▁us -183 -ome -184 -▁not -185 -ard -186 -ud -187 -▁le -188 -res -189 -▁J -190 -ast -191 -.. -192 -ost -193 -▁pl -194 -ear -195 -▁ab -196 -ack -197 -▁su -198 -iv -199 -▁wor -200 -gh -201 -▁all -202 -rou -203 -ide -204 -ould -205 -▁j -206 -ell -207 -ial -208 -te -209 -ak -210 -ine -211 -od -212 -ag -213 -are -214 -▁has -215 -ice -216 -▁U -217 -▁Th -218 -▁do -219 -age -220 -▁k -221 -ook -222 -fe -223 -▁ad -224 -▁me -225 -ip -226 -▁In -227 -▁comp -228 -▁but -229 -▁up -230 -▁out -231 -ake -232 -per -233 -red -234 -▁whe -235 -ions -236 -ally -237 -pt -238 -ry -239 -og -240 -one -241 -▁more -242 -ail -243 -able -244 -ind -245 -▁my -246 -ite -247 -▁our -248 -ther -249 -▁en -250 -▁“ -251 -very -252 -▁Y -253 -▁sa -254 -▁so -255 -ich -256 -ime -257 -cc -258 -▁cl -259 -ong -260 -▁their -261 -▁K -262 -ated -263 -ood -264 -ame -265 -orm -266 -▁St -267 -▁they -268 -▁one -269 -▁te -270 -ber -271 -ace -272 -ike -273 -iz -274 -▁about -275 -so -276 -ous -277 -du -278 -ick -279 -ase -280 -ans -281 -▁" -282 -▁V -283 -pl -284 -▁cont -285 -act -286 -ia -287 -▁im -288 -▁work -289 -▁un -290 -▁who -291 -ree -292 -cl -293 -ire -294 -▁fe -295 -ign -296 -▁off -297 -▁his -298 -▁man -299 -ue -300 -ff -301 -ance -302 -▁go -303 -ll -304 -ach -305 -▁year -306 -▁new -307 -▁tr -308 -ays -309 -ne -310 -reat -311 -▁It -312 -ction -313 -ub -314 -ib -315 -ult -316 -▁app -317 -erv -318 -und -319 -▁We -320 -ap -321 -▁Ch -322 -ass -323 -▁qu -324 -ep -325 -▁res -326 -ary -327 -ark -328 -▁sp -329 -▁per -330 -ations -331 -ile -332 -ove -333 -form -334 -▁int -335 -▁get -336 -▁also -337 -▁time -338 -▁which -339 -ount -340 -ven -341 -▁like -342 -own -343 -▁other -344 -ents -345 -▁some -346 -ond -347 -ord -348 -▁any -349 -ings -350 -vel -351 -av -352 -▁been -353 -ical -354 -▁over -355 -▁part -356 -ress -357 -▁This -358 -▁dis -359 -ks -360 -▁He -361 -ors -362 -ence -363 -▁said -364 -▁sc -365 -▁rec -366 -▁ar -367 -ition -368 -▁them -369 -▁ag -370 -▁when -371 -▁pe -372 -ild -373 -port -374 -▁her -375 -ound -376 -ough -377 -▁kn -378 -ose -379 -ob -380 -irst -381 -low -382 -▁just -383 -mer -384 -int -385 -▁ro -386 -ov -387 -ck -388 -ish -389 -▁what -390 -oy -391 -▁pr -392 -ru -393 -▁spe -394 -▁pre -395 -▁there -396 -ens -397 -wn -398 -▁acc -399 -day -400 -▁if -401 -ren -402 -▁than -403 -▁would -404 -▁need -405 -▁Re -406 -▁had -407 -vers -408 -▁its -409 -▁were -410 -ink -411 -fter -412 -ning -413 -▁am -414 -ater -415 -... -416 -▁des -417 -old -418 -itt -419 -clud -420 -ade -421 -rough -422 -▁tw -423 -▁into -424 -lp -425 -ory -426 -use -427 -ople -428 -ool -429 -ang -430 -▁first -431 -▁how -432 -▁bec -433 -▁help -434 -lic -435 -hed -436 -ons -437 -▁add -438 -anc -439 -ft -440 -▁make -441 -amp -442 -gr -443 -▁bl -444 -▁look -445 -▁– -446 -▁Wh -447 -▁prov -448 -▁col -449 -▁includ -450 -▁people -451 -▁comm -452 -▁produ -453 -▁You -454 -▁Ne -455 -ual -456 -▁know -457 -ful -458 -▁she -459 -ian -460 -ments -461 -ates -462 -iew -463 -round -464 -▁em -465 -▁every -466 -▁back -467 -▁only -468 -▁serv -469 -tern -470 -les -471 -ious -472 -▁no -473 -▁may -474 -rent -475 -▁through -476 -▁bu -477 -ict -478 -▁most -479 -cts -480 -ating -481 -▁see -482 -▁want -483 -▁two -484 -▁ph -485 -com -486 -pport -487 -▁As -488 -xt -489 -we -490 -ities -491 -ices -492 -iss -493 -▁use -494 -▁well -495 -ont -496 -▁bet -497 -▁after -498 -▁If -499 -ise -500 -hing -501 -▁ind -502 -ause -503 -▁play -504 -▁Se -505 -ph -506 -▁und -507 -je -508 -▁& -509 -▁co -510 -ife -511 -▁| -512 -ock -513 -ily -514 -▁stud -515 -lect -516 -row -517 -▁act -518 -ting -519 -iness -520 -▁fl -521 -hen -522 -▁years -523 -▁Com -524 -▁Un -525 -urn -526 -ts -527 -▁$ -528 -enc -529 -aw -530 -▁these -531 -▁tra -532 -▁An -533 -fore -534 -▁cons -535 -▁under -536 -als -537 -cial -538 -ange -539 -▁exper -540 -bs -541 -aking -542 -▁ke -543 -oth -544 -▁now -545 -ures -546 -ational -547 -▁very -548 -▁Pro -549 -▁wee -550 -▁bus -551 -▁good -552 -▁gu -553 -ased -554 -vent -555 -▁And -556 -formation -557 -▁many -558 -▁sm -559 -get -560 -▁way -561 -any -562 -▁reg -563 -erson -564 -oint -565 -ific -566 -ward -567 -▁De -568 -ert -569 -ility -570 -▁start -571 -▁fin -572 -▁dif -573 -▁could -574 -rit -575 -lease -576 -▁great -577 -▁imp -578 -ork -579 -uch -580 -▁day -581 -fect -582 -▁rem -583 -▁Sh -584 -yst -585 -▁rel -586 -ience -587 -ible -588 -▁even -589 -▁For -590 -uring -591 -ty -592 -▁show -593 -▁high -594 -oss -595 -ics -596 -▁sec -597 -ull -598 -▁own -599 -nds -600 -velop -601 -▁inv -602 -▁where -603 -▁here -604 -▁don -605 -▁inc -606 -▁down -607 -). -608 -▁ent -609 -ident -610 -hes -611 -olog -612 -cess -613 -▁loc -614 -arch -615 -▁right -616 -ble -617 -▁then -618 -chool -619 -▁home -620 -▁should -621 -▁Al -622 -▁New -623 -elf -624 -alth -625 -The -626 -▁ass -627 -ied -628 -▁br -629 -its -630 -ited -631 -▁find -632 -ath -633 -air -634 -ular -635 -▁read -636 -▁too -637 -▁ac -638 -hip -639 -▁av -640 -▁set -641 -ix -642 -▁car -643 -▁fam -644 -ner -645 -▁information -646 -▁mon -647 -gan -648 -line -649 -▁best -650 -▁last -651 -ys -652 -▁min -653 -gram -654 -▁take -655 -io -656 -▁design -657 -▁Cl -658 -pect -659 -ract -660 -▁long -661 -ason -662 -▁did -663 -▁inst -664 -▁much -665 -omet -666 -▁che -667 -|| -668 -erm -669 -▁Be -670 -▁business -671 -ystem -672 -▁because -673 -▁before -674 -other -675 -ank -676 -▁dec -677 -ues -678 -▁But -679 -▁att -680 -▁ins -681 -▁Fr -682 -.” -683 -▁made -684 -▁team -685 -ative -686 -▁call -687 -▁Le -688 -▁him -689 -pr -690 -▁sur -691 -pen -692 -atch -693 -▁cre -694 -rib -695 -me -696 -▁think -697 -ject -698 -ollow -699 -az -700 -▁again -701 -▁world -702 -way -703 -ax -704 -ale -705 -ug -706 -▁Ad -707 -▁art -708 -▁mem -709 -▁does -710 -alk -711 -), -712 -▁vis -713 -arket -714 -▁being -715 -▁pres -716 -ave -717 -▁develop -718 -▁person -719 -oun -720 -▁requ -721 -arn -722 -ustom -723 -ower -724 -chn -725 -rest -726 -▁inte -727 -arm -728 -ient -729 -▁life -730 -▁those -731 -ener -732 -▁diffe -733 -▁such -734 -ins -735 -▁med -736 -ng -737 -ivers -738 -ince -739 -ouse -740 -▁support -741 -ving -742 -▁while -743 -ash -744 -irect -745 -▁Ar -746 -▁pol -747 -view -748 -land -749 -▁sk -750 -▁provid -751 -ss -752 -unity -753 -ier -754 -▁lead -755 -▁ra -756 -▁Te -757 -▁each -758 -▁around -759 -▁book -760 -der -761 -▁love -762 -▁free -763 -▁used -764 -ced -765 -akes -766 -▁care -767 -▁end -768 -read -769 -▁mod -770 -ailable -771 -▁ser -772 -▁comple -773 -▁post -774 -▁run -775 -▁gr -776 -ather -777 -▁disc -778 -▁sim -779 -ric -780 -▁program -781 -ality -782 -▁ret -783 -▁pub -784 -ces -785 -ional -786 -ages -787 -ually -788 -▁bo -789 -▁cur -790 -▁ed -791 -ines -792 -imes -793 -ton -794 -ives -795 -▁All -796 -▁det -797 -▁really -798 -roup -799 -ple -800 -oad -801 -ars -802 -▁eas -803 -ets -804 -▁On -805 -▁child -806 -▁system -807 -▁There -808 -▁So -809 -▁num -810 -iel -811 -au -812 -ize -813 -▁follow -814 -▁trans -815 -." -816 -led -817 -ene -818 -▁count -819 -▁going -820 -▁found -821 -,” -822 -▁top -823 -ah -824 -▁form -825 -▁char -826 -▁somet -827 -iet -828 -▁three -829 -ittle -830 -▁inter -831 -▁list -832 -▁cour -833 -ames -834 -man -835 -▁still -836 -▁Bl -837 -▁fun -838 -▁How -839 -▁month -840 -▁available -841 -▁place -842 -▁del -843 -ature -844 -▁Pl -845 -▁custom -846 -ute -847 -ness -848 -▁though -849 -▁They -850 -▁feel -851 -ways -852 -▁prof -853 -▁cle -854 -▁both -855 -▁To -856 -▁few -857 -▁sub -858 -cept -859 -▁aut -860 -orn -861 -meric -862 -▁str -863 -▁happ -864 -▁week -865 -▁sign -866 -▁open -867 -▁hand -868 -ved -869 -▁gl -870 -▁pur -871 -▁say -872 -uc -873 -▁report -874 -▁health -875 -▁game -876 -▁adv -877 -att -878 -▁rep -879 -▁market -880 -ital -881 -▁different -882 -oot -883 -ired -884 -orth -885 -▁frie -886 -bers -887 -▁keep -888 -▁same -889 -ering -890 -tt -891 -▁lot -892 -▁Ex -893 -▁She -894 -▁point -895 -▁Col -896 -ween -897 -▁techn -898 -▁family -899 -▁ev -900 -▁i -901 -ology -902 -▁exp -903 -iqu -904 -▁ext -905 -▁school -906 -ining -907 -▁little -908 -▁using -909 -," -910 -▁process -911 -ished -912 -atur -913 -▁company -914 -▁lar -915 -ata -916 -▁including -917 -▁Sc -918 -ross -919 -iving -920 -oh -921 -ants -922 -▁next -923 -▁plan -924 -▁win -925 -▁Americ -926 -ott -927 -▁fil -928 -▁real -929 -▁during -930 -▁Tr -931 -▁between -932 -thing -933 -ized -934 -▁water -935 -ger -936 -▁sol -937 -▁Ph -938 -▁import -939 -▁Q -940 -ody -941 -cent -942 -▁state -943 -▁What -944 -gg -945 -ield -946 -▁things -947 -ik -948 -ves -949 -▁met -950 -arly -951 -els -952 -▁come -953 -aut -954 -ists -955 -be -956 -▁allow -957 -▁big -958 -less -959 -aint -960 -reen -961 -▁mus -962 -▁put -963 -▁contin -964 -uss -965 -▁Or -966 -▁rece -967 -▁experience -968 -ware -969 -▁service -970 -▁opt -971 -▁build -972 -cer -973 -self -974 -▁small -975 -▁dri -976 -▁days -977 -▁appro -978 -ined -979 -iversity -980 -ex -981 -▁organ -982 -▁full -983 -ling -984 -▁since -985 -▁cent -986 -▁always -987 -▁rest -988 -▁try -989 -▁phot -990 -▁better -991 -▁cr -992 -▁sure -993 -▁When -994 -ution -995 -▁pat -996 -▁online -997 -▁pri -998 -▁quest -999 -▁ref -1000 -▁Ind -1001 -▁second -1002 -▁pass -1003 -▁something -1004 -▁var -1005 -illion -1006 -▁bel -1007 -▁interest -1008 -rand -1009 -ever -1010 -over -1011 -▁iss -1012 -▁partic -1013 -▁class -1014 -▁poss -1015 -▁gener -1016 -▁def -1017 -▁group -1018 -▁tri -1019 -▁mov -1020 -ffect -1021 -▁perform -1022 -▁hard -1023 -▁direct -1024 -▁Z -1025 -▁pay -1026 -pping -1027 -ours -1028 -▁With -1029 -▁result -1030 -▁bro -1031 -▁today -1032 -▁head -1033 -▁special -1034 -gy -1035 -▁— -1036 -▁sl -1037 -ps -1038 -▁ty -1039 -▁ve -1040 -ploy -1041 -ER -1042 -▁At -1043 -joy -1044 -▁stand -1045 -ms -1046 -work -1047 -ared -1048 -outh -1049 -▁another -1050 -▁ide -1051 -▁give -1052 -br -1053 -▁ann -1054 -▁Con -1055 -▁wom -1056 -▁provide -1057 -uck -1058 -▁got -1059 -▁cor -1060 -ccess -1061 -ior -1062 -▁Chr -1063 -ote -1064 -oor -1065 -▁Res -1066 -oney -1067 -▁meet -1068 -▁students -1069 -▁resp -1070 -istr -1071 -▁current -1072 -ense -1073 -ately -1074 -▁wr -1075 -▁without -1076 -ision -1077 -▁conf -1078 -▁Our -1079 -ients -1080 -rence -1081 -ok -1082 -ium -1083 -▁old -1084 -▁area -1085 -ley -1086 -ope -1087 -ards -1088 -▁number -1089 -▁four -1090 -▁bre -1091 -▁cost -1092 -aj -1093 -ems -1094 -ered -1095 -▁able -1096 -ically -1097 -▁soc -1098 -▁val -1099 -▁Sp -1100 -▁invest -1101 -▁must -1102 -con -1103 -▁access -1104 -▁services -1105 -▁unt -1106 -raph -1107 -ats -1108 -ird -1109 -▁ask -1110 -▁working -1111 -▁never -1112 -▁US -1113 -▁Cent -1114 -iver -1115 -▁No -1116 -stand -1117 -ww -1118 -▁webs -1119 -▁proble -1120 -▁public -1121 -▁vide -1122 -ission -1123 -▁visit -1124 -▁important -1125 -ann -1126 -▁light -1127 -pped -1128 -▁fact -1129 -let -1130 -▁sal -1131 -▁level -1132 -▁order -1133 -▁fac -1134 -ged -1135 -▁Comm -1136 -▁My -1137 -▁test -1138 -▁might -1139 -▁exc -1140 -ral -1141 -▁rese -1142 -▁product -1143 -▁local -1144 -▁night -1145 -▁season -1146 -inal -1147 -▁el -1148 -▁incre -1149 -ember -1150 -▁site -1151 -rol -1152 -▁That -1153 -▁sing -1154 -ruct -1155 -ample -1156 -▁expl -1157 -▁Mar -1158 -▁spec -1159 -▁grow -1160 -▁let -1161 -▁ca -1162 -▁proper -1163 -▁less -1164 -ording -1165 -▁enjoy -1166 -▁ob -1167 -▁past -1168 -▁event -1169 -▁products -1170 -▁Man -1171 -▁' -1172 -▁inf -1173 -▁May -1174 -▁looking -1175 -▁food -1176 -here -1177 -lection -1178 -▁within -1179 -▁profess -1180 -▁Fe -1181 -▁Is -1182 -▁data -1183 -▁making -1184 -▁pop -1185 -ertain -1186 -▁until -1187 -ases -1188 -ories -1189 -ffic -1190 -enn -1191 -ency -1192 -▁children -1193 -ently -1194 -▁University -1195 -We -1196 -gin -1197 -sh -1198 -▁job -1199 -▁offer -1200 -▁law -1201 -ery -1202 -ains -1203 -ney -1204 -urs -1205 -▁pos -1206 -eng -1207 -utes -1208 -▁power -1209 -▁view -1210 -▁turn -1211 -▁eng -1212 -▁email -1213 -ential -1214 -tend -1215 -▁oper -1216 -▁sit -1217 -▁check -1218 -▁against -1219 -ieve -1220 -▁est -1221 -▁Pr -1222 -ream -1223 -ised -1224 -▁Br -1225 -ina -1226 -▁prote -1227 -ids -1228 -ode -1229 -▁room -1230 -▁contact -1231 -IN -1232 -▁community -1233 -med -1234 -to -1235 -▁addition -1236 -▁prom -1237 -▁says -1238 -▁intern -1239 -load -1240 -▁toget -1241 -▁together -1242 -▁Fl -1243 -▁away -1244 -ivid -1245 -▁impro -1246 -▁quality -1247 -▁leg -1248 -ator -1249 -▁dist -1250 -▁creat -1251 -ills -1252 -irl -1253 -hor -1254 -▁indust -1255 -▁complete -1256 -▁news -1257 -aring -1258 -iron -1259 -ique -1260 -ret -1261 -▁App -1262 -icle -1263 -iday -1264 -agement -1265 -ified -1266 -oci -1267 -▁supp -1268 -osed -1269 -ability -1270 -▁project -1271 -▁website -1272 -▁Car -1273 -iety -1274 -ane -1275 -por -1276 -!! -1277 -▁change -1278 -co -1279 -▁success -1280 -▁dep -1281 -bo -1282 -▁learn -1283 -▁include -1284 -▁Co -1285 -pend -1286 -▁fav -1287 -▁chang -1288 -ym -1289 -▁Ste -1290 -▁detail -1291 -ism -1292 -▁offic -1293 -▁Can -1294 -▁members -1295 -▁dr -1296 -arent -1297 -son -1298 -▁buy -1299 -▁easy -1300 -▁please -1301 -rap -1302 -▁Me -1303 -aster -1304 -▁applic -1305 -ising -1306 -ury -1307 -▁name -1308 -▁pract -1309 -▁times -1310 -atures -1311 -▁along -1312 -▁equ -1313 -▁present -1314 -▁One -1315 -▁large -1316 -▁money -1317 -▁beaut -1318 -atter -1319 -augh -1320 -▁Am -1321 -aterial -1322 -the -1323 -▁Cont -1324 -iting -1325 -▁activ -1326 -vern -1327 -RE -1328 -▁employ -1329 -▁la -1330 -aff -1331 -une -1332 -▁house -1333 -ready -1334 -Th -1335 -▁course -1336 -▁expect -1337 -▁. -1338 -▁needs -1339 -ored -1340 -▁air -1341 -▁left -1342 -▁Christ -1343 -▁thing -1344 -itions -1345 -ift -1346 -sc -1347 -ably -1348 -▁cap -1349 -ider -1350 -ived -1351 -lish -1352 -▁music -1353 -▁dra -1354 -min -1355 -▁why -1356 -▁En -1357 -yle -1358 -ohn -1359 -ump -1360 -ify -1361 -▁hist -1362 -ec -1363 -ron -1364 -by -1365 -▁bas -1366 -ern -1367 -▁hum -1368 -▁video -1369 -rie -1370 -▁sw -1371 -▁account -1372 -ON -1373 -ffe -1374 -alf -1375 -ocus -1376 -veral -1377 -▁below -1378 -▁soft -1379 -▁hot -1380 -▁These -1381 -▁short -1382 -ries -1383 -▁Eng -1384 -▁line -1385 -▁live -1386 -pecial -1387 -▁opport -1388 -enef -1389 -▁create -1390 -book -1391 -▁cond -1392 -▁beh -1393 -▁... -1394 -▁perfect -1395 -uly -1396 -▁ce -1397 -▁page -1398 -▁word -1399 -▁/ -1400 -▁writ -1401 -AT -1402 -▁dem -1403 -ots -1404 -▁Med -1405 -▁mar -1406 -▁Please -1407 -fort -1408 -side -1409 -ows -1410 -mber -1411 -▁govern -1412 -▁pa -1413 -artment -1414 -▁already -1415 -▁Che -1416 -▁kind -1417 -▁After -1418 -▁enough -1419 -▁ever -1420 -▁research -1421 -ured -1422 -▁makes -1423 -▁following -1424 -▁million -1425 -▁Do -1426 -▁review -1427 -▁getting -1428 -▁dev -1429 -ten -1430 -itive -1431 -ush -1432 -▁friends -1433 -▁cut -1434 -▁conne -1435 -▁trad -1436 -ee -1437 -., -1438 -▁record -1439 -room -1440 -▁treat -1441 -▁side -1442 -▁const -1443 -vious -1444 -▁Ass -1445 -▁case -1446 -▁having -1447 -ajor -1448 -▁tell -1449 -▁Count -1450 -▁personal -1451 -▁move -1452 -▁based -1453 -▁story -1454 -viron -1455 -ention -1456 -▁John -1457 -rop -1458 -▁Your -1459 -▁Serv -1460 -▁won -1461 -unch -1462 -ips -1463 -▁Des -1464 -▁minutes -1465 -uper -1466 -▁become -1467 -uture -1468 -▁possible -1469 -osp -1470 -oice -1471 -iam -1472 -▁talk -1473 -▁city -1474 -ights -1475 -▁across -1476 -▁vers -1477 -▁share -1478 -ization -1479 -▁done -1480 -▁bit -1481 -▁camp -1482 -▁pack -1483 -▁didn -1484 -▁comes -1485 -▁men -1486 -▁understand -1487 -ead -1488 -▁several -1489 -▁-- -1490 -yn -1491 -▁: -1492 -▁country -1493 -▁Tw -1494 -▁hours -1495 -▁effect -1496 -▁cou -1497 -▁purch -1498 -iven -1499 -▁benef -1500 -ES -1501 -▁mil -1502 -▁women -1503 -uff -1504 -▁net -1505 -ividual -1506 -app -1507 -aces -1508 -▁percent -1509 -▁Comp -1510 -▁educ -1511 -wards -1512 -▁focus -1513 -▁often -1514 -▁material -1515 -ball -1516 -▁social -1517 -aim -1518 -▁elect -1519 -▁Wor -1520 -idd -1521 -ances -1522 -ination -1523 -uro -1524 -ides -1525 -ober -1526 -▁quick -1527 -▁Not -1528 -▁development -1529 -▁es -1530 -▁bring -1531 -▁return -1532 -orts -1533 -▁American -1534 -ister -1535 -ienc -1536 -▁doing -1537 -▁Bro -1538 -▁School -1539 -ript -1540 -▁pie -1541 -▁X -1542 -▁far -1543 -▁hold -1544 -arl -1545 -▁mult -1546 -ted -1547 -▁body -1548 -arr -1549 -err -1550 -▁Gr -1551 -of -1552 -mend -1553 -▁pot -1554 -ference -1555 -iful -1556 -ones -1557 -AN -1558 -▁wa -1559 -ners -1560 -▁fund -1561 -▁took -1562 -ograph -1563 -▁Here -1564 -▁tre -1565 -ource -1566 -lished -1567 -▁blog -1568 -oose -1569 -itc -1570 -AR -1571 -▁State -1572 -▁doesn -1573 -reet -1574 -conom -1575 -▁jo -1576 -vironment -1577 -▁deal -1578 -lement -1579 -▁others -1580 -▁City -1581 -▁Rep -1582 -▁came -1583 -▁called -1584 -▁started -1585 -▁sum -1586 -▁rele -1587 -org -1588 -▁Inst -1589 -nder -1590 -▁least -1591 -▁months -1592 -▁Intern -1593 -▁space -1594 -acy -1595 -▁Gu -1596 -▁mom -1597 -▁future -1598 -▁orig -1599 -▁compet -1600 -▁individual -1601 -oon -1602 -lege -1603 -▁went -1604 -▁occ -1605 -▁yet -1606 -▁young -1607 -rodu -1608 -▁clean -1609 -▁non -1610 -▁mind -1611 -▁told -1612 -ai -1613 -▁five -1614 -▁early -1615 -▁series -1616 -▁control -1617 -af -1618 -utions -1619 -▁term -1620 -▁major -1621 -oll -1622 -hers -1623 -ille -1624 -ape -1625 -▁games -1626 -ained -1627 -▁comb -1628 -▁means -1629 -▁pict -1630 -▁industry -1631 -▁chall -1632 -yl -1633 -▁tool -1634 -anks -1635 -▁Min -1636 -▁ens -1637 -▁lim -1638 -▁cover -1639 -ctor -1640 -▁fore -1641 -▁ago -1642 -AS -1643 -▁low -1644 -sw -1645 -▁key -1646 -fer -1647 -ama -1648 -▁x -1649 -▁heart -1650 -▁features -1651 -▁Ed -1652 -ilt -1653 -▁tem -1654 -rew -1655 -▁price -1656 -unic -1657 -▁store -1658 -fact -1659 -jects -1660 -▁offers -1661 -▁Ab -1662 -itor -1663 -back -1664 -▁once -1665 -▁specific -1666 -come -1667 -▁range -1668 -▁thought -1669 -ges -1670 -urity -1671 -ither -1672 -ateg -1673 -▁Bo -1674 -▁Jan -1675 -sel -1676 -▁pick -1677 -illed -1678 -▁Now -1679 -eral -1680 -▁God -1681 -▁Dr -1682 -▁favor -1683 -▁appear -1684 -year -1685 -▁More -1686 -▁York -1687 -ilities -1688 -▁Ke -1689 -▁Im -1690 -▁hope -1691 -▁redu -1692 -▁discuss -1693 -OR -1694 -ibr -1695 -▁happen -1696 -▁require -1697 -yr -1698 -▁Pe -1699 -▁However -1700 -atic -1701 -It -1702 -▁mean -1703 -▁single -1704 -nes -1705 -▁step -1706 -▁close -1707 -▁upd -1708 -▁land -1709 -▁break -1710 -▁ey -1711 -▁main -1712 -▁invol -1713 -most -1714 -anies -1715 -▁Pres -1716 -ourn -1717 -▁stay -1718 -▁government -1719 -▁Em -1720 -isk -1721 -isc -1722 -// -1723 -▁Sm -1724 -ony -1725 -▁field -1726 -de -1727 -▁priv -1728 -▁United -1729 -▁beautiful -1730 -resh -1731 -cle -1732 -▁Per -1733 -▁friend -1734 -▁everything -1735 -▁Qu -1736 -▁walk -1737 -ched -1738 -▁questions -1739 -▁added -1740 -▁hig -1741 -▁Cal -1742 -▁tax -1743 -aken -1744 -▁customers -1745 -▁strong -1746 -now -1747 -▁taking -1748 -▁install -1749 -for -1750 -:// -1751 -aps -1752 -ging -1753 -▁Pol -1754 -▁charact -1755 -▁wond -1756 -▁South -1757 -▁begin -1758 -▁study -1759 -ources -1760 -▁North -1761 -▁Just -1762 -▁announ -1763 -ief -1764 -ensive -1765 -▁miss -1766 -▁recom -1767 -▁travel -1768 -▁certain -1769 -▁Park -1770 -▁address -1771 -▁problem -1772 -▁By -1773 -▁County -1774 -▁actually -1775 -play -1776 -▁staff -1777 -▁tot -1778 -▁half -1779 -▁mess -1780 -▁z -1781 -aur -1782 -ew -1783 -inc -1784 -ians -1785 -▁search -1786 -▁technology -1787 -▁girl -1788 -▁media -1789 -urther -1790 -time -1791 -▁watch -1792 -▁typ -1793 -▁known -1794 -▁official -1795 -▁manag -1796 -▁National -1797 -▁six -1798 -irm -1799 -▁Pre -1800 -▁wind -1801 -▁enc -1802 -gle -1803 -atural -1804 -ural -1805 -▁front -1806 -ublic -1807 -▁Add -1808 -▁sound -1809 -▁improve -1810 -▁Post -1811 -wh -1812 -▁dig -1813 -irt -1814 -▁lat -1815 -▁content -1816 -▁Su -1817 -▁Stud -1818 -▁anal -1819 -▁track -1820 -itted -1821 -▁Mc -1822 -▁face -1823 -▁training -1824 -▁link -1825 -▁click -1826 -icy -1827 -▁ste -1828 -▁web -1829 -▁someone -1830 -ison -1831 -▁Oct -1832 -arning -1833 -▁works -1834 -▁author -1835 -▁later -1836 -▁building -1837 -not -1838 -lebr -1839 -▁host -1840 -ocu -1841 -▁Gl -1842 -▁environment -1843 -abor -1844 -cted -1845 -▁Center -1846 -▁mor -1847 -▁log -1848 -▁unique -1849 -▁everyone -1850 -▁Reg -1851 -raft -1852 -▁port -1853 -▁provides -1854 -IS -1855 -gest -1856 -▁ener -1857 -▁fall -1858 -▁cred -1859 -▁seen -1860 -▁Dep -1861 -▁film -1862 -ask -1863 -▁Day -1864 -▁prep -1865 -▁oil -1866 -▁particular -1867 -▁professional -1868 -▁aud -1869 -fully -1870 -▁Aug -1871 -▁Euro -1872 -ests -1873 -▁particip -1874 -lex -1875 -ided -1876 -unities -1877 -▁bar -1878 -ibility -1879 -▁results -1880 -▁ident -1881 -▁recommend -1882 -roll -1883 -▁press -1884 -ED -1885 -▁card -1886 -▁While -1887 -▁Will -1888 -▁whole -1889 -▁Don -1890 -aturday -1891 -▁World -1892 -rain -1893 -▁companies -1894 -ino -1895 -▁Ge -1896 -▁High -1897 -urch -1898 -▁Friday -1899 -▁office -1900 -IT -1901 -pper -1902 -▁Bar -1903 -▁March -1904 -▁color -1905 -▁events -1906 -▁anything -1907 -▁issues -1908 -EN -1909 -ancial -1910 -▁mot -1911 -▁eff -1912 -▁prob -1913 -▁mag -1914 -▁areas -1915 -▁pret -1916 -resent -1917 -▁vol -1918 -▁Some -1919 -▁comput -1920 -▁respons -1921 -ops -1922 -▁points -1923 -▁Acc -1924 -▁performance -1925 -▁near -1926 -▁pain -1927 -ster -1928 -obile -1929 -▁red -1930 -▁print -1931 -▁cook -1932 -▁Apr -1933 -itch -1934 -umb -1935 -▁given -1936 -▁history -1937 -▁econom -1938 -pecially -1939 -crib -1940 -obal -1941 -.... -1942 -▁feature -1943 -go -1944 -ili -1945 -ands -1946 -▁sell -1947 -▁designed -1948 -▁above -1949 -ches -1950 -▁maint -1951 -▁skin -1952 -▁text -1953 -▁aff -1954 -▁simple -1955 -eth -1956 -▁assist -1957 -IC -1958 -my -1959 -ued -1960 -▁age -1961 -icult -1962 -▁reason -1963 -inks -1964 -In -1965 -▁size -1966 -▁question -1967 -▁dou -1968 -imate -1969 -▁according -1970 -▁repl -1971 -iod -1972 -ply -1973 -▁Sec -1974 -nding -1975 -▁black -1976 -▁Aust -1977 -head -1978 -▁htt -1979 -edd -1980 -▁pretty -1981 -▁foot -1982 -▁believe -1983 -▁Saturday -1984 -oved -1985 -ables -1986 -▁due -1987 -▁Part -1988 -▁among -1989 -▁select -1990 -AL -1991 -itter -1992 -▁Sund -1993 -▁fire -1994 -cript -1995 -▁phys -1996 -omes -1997 -ental -1998 -ledge -1999 -▁idea -2000 -ety -2001 -▁latest -2002 -▁details -2003 -▁ant -2004 -▁popular -2005 -ole -2006 -▁third -2007 -▁et -2008 -ators -2009 -▁Mr -2010 -pro -2011 -val -2012 -▁management -2013 -aining -2014 -itional -2015 -▁includes -2016 -ruction -2017 -asing -2018 -▁July -2019 -▁energy -2020 -▁items -2021 -ze -2022 -▁weeks -2023 -ouch -2024 -onday -2025 -▁sent -2026 -▁Feb -2027 -▁living -2028 -ites -2029 -▁cult -2030 -▁receive -2031 -▁fre -2032 -▁continue -2033 -▁bad -2034 -▁June -2035 -▁relations -2036 -▁Europe -2037 -vert -2038 -astic -2039 -idence -2040 -▁human -2041 -▁parent -2042 -ulation -2043 -▁Val -2044 -▁His -2045 -▁claim -2046 -aily -2047 -▁Sept -2048 -ufact -2049 -ctions -2050 -elt -2051 -▁Dav -2052 -▁sex -2053 -▁prop -2054 -▁soon -2055 -ung -2056 -▁property -2057 -▁hon -2058 -nov -2059 -▁currently -2060 -▁amount -2061 -▁entire -2062 -new -2063 -▁West -2064 -uation -2065 -▁coming -2066 -ese -2067 -though -2068 -ana -2069 -ogn -2070 -▁Off -2071 -▁kids -2072 -▁TH -2073 -▁Tra -2074 -▁From -2075 -itting -2076 -▁phone -2077 -This -2078 -cast -2079 -▁final -2080 -▁consum -2081 -▁ess -2082 -▁happy -2083 -▁taken -2084 -▁celebr -2085 -▁docu -2086 -▁member -2087 -icro -2088 -.) -2089 -▁answ -2090 -▁meas -2091 -AC -2092 -▁wanted -2093 -▁type -2094 -▁software -2095 -selves -2096 -▁experienc -2097 -▁forward -2098 -▁diff -2099 -eds -2100 -▁whether -2101 -▁Us -2102 -▁wide -2103 -▁Read -2104 -▁either -2105 -▁Bu -2106 -ires -2107 -▁El -2108 -▁value -2109 -▁concer -2110 -▁deb -2111 -▁further -2112 -ux -2113 -ilar -2114 -ival -2115 -▁isn -2116 -▁coll -2117 -used -2118 -ams -2119 -aced -2120 -▁par -2121 -▁almost -2122 -▁required -2123 -▁crit -2124 -▁held -2125 -▁white -2126 -arter -2127 -▁date -2128 -▁comfort -2129 -▁quite -2130 -▁trying -2131 -▁provided -2132 -▁summer -2133 -▁Sw -2134 -▁fit -2135 -▁Pa -2136 -▁sugg -2137 -▁needed -2138 -▁favorite -2139 -▁tit -2140 -St -2141 -ees -2142 -▁Sunday -2143 -▁opportunity -2144 -▁Jo -2145 -▁ach -2146 -aching -2147 -uary -2148 -ek -2149 -▁Cor -2150 -▁via -2151 -▁extra -2152 -▁players -2153 -▁April -2154 -▁books -2155 -▁Monday -2156 -▁network -2157 -▁cop -2158 -amer -2159 -ler -2160 -▁example -2161 -▁box -2162 -▁users -2163 -▁, -2164 -itten -2165 -▁seem -2166 -▁period -2167 -▁various -2168 -▁Health -2169 -▁options -2170 -where -2171 -▁running -2172 -gress -2173 -▁style -2174 -▁especially -2175 -▁consider -2176 -▁yourself -2177 -▁Art -2178 -▁dam -2179 -▁safe -2180 -▁previous -2181 -▁swe -2182 -▁ways -2183 -▁version -2184 -▁created -2185 -▁sle -2186 -▁Mon -2187 -▁recently -2188 -▁potential -2189 -OU -2190 -▁issue -2191 -▁common -2192 -ises -2193 -▁di -2194 -▁Inc -2195 -▁stri -2196 -▁ready -2197 -▁attend -2198 -▁morning -2199 -▁regular -2200 -▁insp -2201 -▁else -2202 -▁road -2203 -▁nice -2204 -▁throughout -2205 -▁probably -2206 -▁ensure -2207 --- -2208 -▁veh -2209 -▁received -2210 -earch -2211 -▁ball -2212 -▁Associ -2213 -▁President -2214 -▁clear -2215 -▁download -2216 -par -2217 -icles -2218 -▁engine -2219 -▁sho -2220 -erc -2221 -▁song -2222 -azing -2223 -▁lo -2224 -▁brand -2225 -▁relationship -2226 -▁takes -2227 -▁reading -2228 -mit -2229 -▁natural -2230 -▁Aut -2231 -▁States -2232 -ades -2233 -amed -2234 -▁park -2235 -▁House -2236 -ively -2237 -▁shows -2238 -▁asked -2239 -▁medical -2240 -istration -2241 -ague -2242 -▁inj -2243 -▁hit -2244 -▁choose -2245 -▁collect -2246 -▁Direct -2247 -▁Mich -2248 -▁original -2249 -▁cool -2250 -▁spr -2251 -▁couple -2252 -angu -2253 -reme -2254 -ipping -2255 -▁represent -2256 -▁bott -2257 -▁init -2258 -▁release -2259 -▁goal -2260 -▁behind -2261 -ny -2262 -apt -2263 -oid -2264 -▁Face -2265 -▁wonder -2266 -▁Soc -2267 -▁recent -2268 -▁sales -2269 -eter -2270 -▁clients -2271 -▁financial -2272 -aging -2273 -overed -2274 -▁accom -2275 -▁fresh -2276 -▁fast -2277 -▁super -2278 -▁leave -2279 -▁problems -2280 -▁anyone -2281 -▁role -2282 -face -2283 -▁Get -2284 -gs -2285 -hib -2286 -▁Ser -2287 -▁career -2288 -uge -2289 -▁Fin -2290 -bor -2291 -▁Black -2292 -ume -2293 -▁cup -2294 -ried -2295 -ville -2296 -▁model -2297 -▁article -2298 -oura -2299 -▁ful -2300 -uesday -2301 -▁meth -2302 -arth -2303 -▁ground -2304 -▁programs -2305 -▁Up -2306 -▁hol -2307 -▁fail -2308 -na -2309 -▁sun -2310 -aving -2311 -▁weeke -2312 -▁accept -2313 -▁flow -2314 -ada -2315 -ursday -2316 -▁base -2317 -medi -2318 -▁customer -2319 -▁difficult -2320 -OT -2321 -atform -2322 -▁writing -2323 -anced -2324 -urance -2325 -▁looks -2326 -▁PM -2327 -▁tour -2328 -▁polit -2329 -▁likely -2330 -ox -2331 -hel -2332 -oogle -2333 -▁paper -2334 -▁ap -2335 -▁abs -2336 -▁simply -2337 -cing -2338 -name -2339 -verage -2340 -▁inside -2341 -▁manufact -2342 -▁TV -2343 -clus -2344 -▁etc -2345 -▁mix -2346 -▁total -2347 -▁included -2348 -▁po -2349 -idge -2350 -ming -2351 -▁Int -2352 -▁risk -2353 -▁Wed -2354 -adem -2355 -aker -2356 -▁increase -2357 -▁party -2358 -▁changes -2359 -▁ele -2360 -ashing -2361 -▁board -2362 -▁education -2363 -oud -2364 -▁Her -2365 -▁October -2366 -▁action -2367 -▁former -2368 -▁meeting -2369 -Wh -2370 -▁however -2371 -▁News -2372 -▁outside -2373 -ification -2374 -uit -2375 -iple -2376 -▁match -2377 -▁Ac -2378 -▁America -2379 -▁Act -2380 -▁nothing -2381 -▁security -2382 -▁self -2383 -ground -2384 -▁contrib -2385 -▁stop -2386 -ester -2387 -▁town -2388 -▁August -2389 -▁matter -2390 -▁position -2391 -▁Af -2392 -▁ple -2393 -▁bed -2394 -▁late -2395 -istrict -2396 -▁Ob -2397 -▁systems -2398 -▁Every -2399 -icated -2400 -adu -2401 -ules -2402 -▁Bus -2403 -▁words -2404 -▁playing -2405 -▁cir -2406 -▁pan -2407 -ST -2408 -▁UK -2409 -wood -2410 -▁sat -2411 -▁impact -2412 -▁anim -2413 -▁mark -2414 -▁private -2415 -▁application -2416 -▁police -2417 -▁knowledge -2418 -▁exist -2419 -▁photos -2420 -▁method -2421 -▁longer -2422 -▁coun -2423 -▁worked -2424 -iddle -2425 -▁national -2426 -▁projects -2427 -ederal -2428 -▁ord -2429 -▁Are -2430 -▁necess -2431 -ude -2432 -▁table -2433 -▁stra -2434 -off -2435 -▁Ag -2436 -empt -2437 -elcome -2438 -▁September -2439 -ecut -2440 -▁activities -2441 -▁worth -2442 -▁recogn -2443 -▁production -2444 -str -2445 -nesday -2446 -▁Department -2447 -based -2448 -aby -2449 -iff -2450 -▁comment -2451 -▁compl -2452 -▁skills -2453 -▁true -2454 -▁general -2455 -▁Austral -2456 -▁January -2457 -iol -2458 -▁round -2459 -▁lives -2460 -▁learning -2461 -▁Tuesday -2462 -▁Thursday -2463 -ID -2464 -che -2465 -▁Then -2466 -▁introdu -2467 -ky -2468 -arden -2469 -▁signific -2470 -ING -2471 -oom -2472 -▁Sal -2473 -▁ill -2474 -▁student -2475 -▁Pat -2476 -▁lay -2477 -▁hair -2478 -▁Free -2479 -▁Nove -2480 -▁computer -2481 -▁squ -2482 -▁purchase -2483 -▁tal -2484 -ham -2485 -▁Also -2486 -ession -2487 -ett -2488 -▁Mus -2489 -▁death -2490 -▁defin -2491 -▁seems -2492 -▁Of -2493 -ci -2494 -▁hands -2495 -izing -2496 -▁communic -2497 -mon -2498 -▁rad -2499 -▁choice -2500 -▁screen -2501 -AM -2502 -▁draw -2503 -▁concern -2504 -▁leading -2505 -▁additional -2506 -▁First -2507 -▁rights -2508 -attle -2509 -▁cell -2510 -▁credit -2511 -▁located -2512 -▁variety -2513 -▁leaders -2514 -▁Facebook -2515 -▁stat -2516 -▁tick -2517 -▁drive -2518 -▁movie -2519 -▁San -2520 -arget -2521 -oring -2522 -▁file -2523 -▁fig -2524 -ipment -2525 -▁hy -2526 -▁bud -2527 -▁image -2528 -▁determ -2529 -▁amazing -2530 -aign -2531 -▁Sim -2532 -▁suggest -2533 -mercial -2534 -▁chance -2535 -▁Red -2536 -▁associ -2537 -▁rather -2538 -▁practice -2539 -▁built -2540 -▁plans -2541 -▁function -2542 -oph -2543 -▁Har -2544 -▁providing -2545 -iter -2546 -▁cal -2547 -ached -2548 -airs -2549 -light -2550 -ought -2551 -urg -2552 -pm -2553 -▁War -2554 -▁vict -2555 -▁court -2556 -▁aw -2557 -▁saf -2558 -▁cand -2559 -example -2560 -▁Out -2561 -▁touch -2562 -▁Air -2563 -▁teac -2564 -cil -2565 -▁exam -2566 -▁autom -2567 -▁Street -2568 -▁international -2569 -▁loss -2570 -▁weekend -2571 -▁Wind -2572 -▁infl -2573 -▁prior -2574 -▁prevent -2575 -▁allows -2576 -▁arri -2577 -▁Calif -2578 -▁Click -2579 -irth -2580 -ibrary -2581 -▁character -2582 -▁piece -2583 -▁treatment -2584 -cember -2585 -itchen -2586 -olution -2587 -▁http -2588 -ma -2589 -▁similar -2590 -▁Most -2591 -▁moment -2592 -gar -2593 -oke -2594 -ruary -2595 -▁clos -2596 -▁Design -2597 -▁investig -2598 -▁rate -2599 -▁AM -2600 -reg -2601 -▁commit -2602 -▁growth -2603 -imum -2604 -▁norm -2605 -OM -2606 -iber -2607 -▁Dis -2608 -ivery -2609 -▁estab -2610 -▁cause -2611 -▁user -2612 -sp -2613 -▁deg -2614 -▁lost -2615 -▁display -2616 -▁collection -2617 -▁myself -2618 -▁Cr -2619 -▁op -2620 -▁enter -2621 -▁Wednesday -2622 -unt -2623 -▁rout -2624 -ault -2625 -▁decided -2626 -▁decision -2627 -▁sil -2628 -▁inde -2629 -▁Any -2630 -▁higher -2631 -cy -2632 -▁bal -2633 -▁daily -2634 -ha -2635 -ournal -2636 -▁digital -2637 -▁November -2638 -▁purp -2639 -▁Group -2640 -▁released -2641 -▁significant -2642 -▁reported -2643 -LE -2644 -▁Home -2645 -▁woman -2646 -▁Cour -2647 -▁easily -2648 -▁cannot -2649 -▁goes -2650 -▁International -2651 -▁excell -2652 -lin -2653 -▁wall -2654 -▁Thanks -2655 -▁quickly -2656 -▁College -2657 -▁usually -2658 -amb -2659 -▁bag -2660 -▁apply -2661 -▁floor -2662 -▁expected -2663 -iant -2664 -▁involved -2665 -▁Law -2666 -▁dom -2667 -▁attack -2668 -just -2669 -▁boy -2670 -illing -2671 -▁regard -2672 -▁platform -2673 -▁capt -2674 -▁iP -2675 -▁Net -2676 -▁encoura -2677 -▁protect -2678 -ondon -2679 -▁Cons -2680 -▁agree -2681 -ael -2682 -▁serious -2683 -▁December -2684 -▁safety -2685 -▁roll -2686 -▁saw -2687 -▁dress -2688 -▁Google -2689 -▁gen -2690 -▁parents -2691 -▁mach -2692 -idents -2693 -▁played -2694 -▁Service -2695 -▁immedi -2696 -▁surpr -2697 -mas -2698 -▁warm -2699 -zz -2700 -▁integr -2701 -▁mobile -2702 -▁tast -2703 -ica -2704 -▁February -2705 -▁sn -2706 -▁club -2707 -▁langu -2708 -▁president -2709 -▁sche -2710 -▁related -2711 -hern -2712 -▁shoot -2713 -▁finish -2714 -▁ideas -2715 -▁global -2716 -▁marketing -2717 -▁tools -2718 -▁ep -2719 -▁expert -2720 -band -2721 -▁code -2722 -▁exact -2723 -ospital -2724 -asons -2725 -▁mass -2726 -▁note -2727 -avy -2728 -▁photo -2729 -izes -2730 -▁save -2731 -▁source -2732 -▁ut -2733 -▁option -2734 -▁respect -2735 -▁Brit -2736 -▁Let -2737 -▁feed -2738 -enge -2739 -iding -2740 -▁arch -2741 -▁deep -2742 -▁corre -2743 -▁Ang -2744 -▁announced -2745 -ilies -2746 -▁appe -2747 -edding -2748 -▁Well -2749 -cription -2750 -▁La -2751 -www -2752 -hood -2753 -reng -2754 -▁stock -2755 -▁sens -2756 -▁admin -2757 -▁location -2758 -▁ri -2759 -ellow -2760 -▁gets -2761 -▁David -2762 -▁costs -2763 -▁helps -2764 -▁Av -2765 -ples -2766 -▁materials -2767 -ength -2768 -▁Je -2769 -ipe -2770 -rab -2771 -▁Tex -2772 -▁huge -2773 -▁published -2774 -agn -2775 -like -2776 -AP -2777 -▁send -2778 -▁mother -2779 -▁benefits -2780 -▁English -2781 -enior -2782 -mission -2783 -ography -2784 -▁lab -2785 -oday -2786 -▁Play -2787 -▁fight -2788 -▁Over -2789 -▁hear -2790 -▁weight -2791 -rown -2792 -▁Spr -2793 -ornia -2794 -uel -2795 -vey -2796 -iction -2797 -▁images -2798 -rought -2799 -▁restaur -2800 -key -2801 -▁gar -2802 -▁Book -2803 -▁earn -2804 -ald -2805 -▁ability -2806 -▁interview -2807 -add -2808 -▁Check -2809 -▁Business -2810 -atory -2811 -▁London -2812 -ructure -2813 -▁written -2814 -akers -2815 -▁challeng -2816 -▁standard -2817 -▁gives -2818 -▁giving -2819 -▁ones -2820 -▁legal -2821 -▁sense -2822 -▁campaign -2823 -▁Sch -2824 -▁dest -2825 -▁innov -2826 -erved -2827 -▁door -2828 -▁patients -2829 -rom -2830 -▁mid -2831 -▁trust -2832 -urt -2833 -▁sus -2834 -▁wasn -2835 -▁Services -2836 -▁center -2837 -▁instead -2838 -aged -2839 -▁Produ -2840 -▁fab -2841 -▁Coun -2842 -▁heat -2843 -▁neg -2844 -▁fine -2845 -▁item -2846 -▁Great -2847 -▁target -2848 -erous -2849 -▁prem -2850 -erve -2851 -▁sold -2852 -▁White -2853 -aught -2854 -▁wish -2855 -▁Trans -2856 -▁parts -2857 -▁write -2858 -▁levels -2859 -▁lic -2860 -▁award -2861 -iring -2862 -arant -2863 -aves -2864 -▁cases -2865 -▁describ -2866 -▁picture -2867 -▁pers -2868 -▁partners -2869 -▁Web -2870 -▁dry -2871 -▁neigh -2872 -irit -2873 -▁Mod -2874 -▁Prof -2875 -▁stuff -2876 -ashington -2877 -ida -2878 -▁pull -2879 -▁conditions -2880 -▁ded -2881 -atives -2882 -▁green -2883 -▁California -2884 -▁broad -2885 -▁effic -2886 -▁Hol -2887 -board -2888 -▁Hall -2889 -put -2890 -rows -2891 -▁Program -2892 -ivity -2893 -▁began -2894 -▁sale -2895 -▁upon -2896 -istic -2897 -▁highly -2898 -▁interesting -2899 -TM -2900 -bit -2901 -OS -2902 -▁vot -2903 -▁fans -2904 -▁stories -2905 -inner -2906 -▁request -2907 -▁contract -2908 -▁remember -2909 -▁slow -2910 -▁Cle -2911 -▁emer -2912 -▁subs -2913 -▁answer -2914 -▁Techn -2915 -anch -2916 -▁comments -2917 -acing -2918 -ocol -2919 -▁bra -2920 -▁Phot -2921 -▁wood -2922 -▁Other -2923 -▁lower -2924 -▁sym -2925 -▁dead -2926 -orge -2927 -▁prim -2928 -orage -2929 -▁modern -2930 -▁player -2931 -▁cat -2932 -coming -2933 -bum -2934 -▁interested -2935 -ooth -2936 -▁reports -2937 -aches -2938 -▁except -2939 -ara -2940 -lev -2941 -▁dise -2942 -▁trip -2943 -▁teams -2944 -▁Jack -2945 -▁Texas -2946 -▁attention -2947 -▁equipment -2948 -▁paint -2949 -sy -2950 -▁fully -2951 -▁wrong -2952 -▁directly -2953 -▁starting -2954 -▁completely -2955 -▁organization -2956 -▁types -2957 -uk -2958 -wide -2959 -▁Green -2960 -mm -2961 -▁resources -2962 -▁Last -2963 -▁www -2964 -ET -2965 -urb -2966 -ager -2967 -▁document -2968 -▁themselves -2969 -apan -2970 -▁dru -2971 -▁solutions -2972 -▁stru -2973 -▁viol -2974 -ashion -2975 -▁bank -2976 -▁Washington -2977 -▁Loc -2978 -▁Rem -2979 -ament -2980 -▁multiple -2981 -▁Association -2982 -▁band -2983 -▁achieve -2984 -▁condition -2985 -▁gold -2986 -▁businesses -2987 -▁Twitter -2988 -uses -2989 -▁wait -2990 -ule -2991 -▁Go -2992 -ening -2993 -udd -2994 -▁Each -2995 -▁affect -2996 -▁opportunities -2997 -▁vac -2998 -▁Gener -2999 -urer -3000 -▁hop -3001 -EC -3002 -▁sett -3003 -▁policy -3004 -▁Par -3005 -▁led -3006 -ension -3007 -▁thinking -3008 -▁dream -3009 -▁Once -3010 -raz -3011 -rel -3012 -▁groups -3013 -▁planning -3014 -▁commercial -3015 -EO -3016 -He -3017 -ffee -3018 -olf -3019 -▁Spe -3020 -▁separ -3021 -▁applications -3022 -▁qual -3023 -▁streng -3024 -▁approach -3025 -▁families -3026 -▁solution -3027 -▁Del -3028 -▁firm -3029 -▁Class -3030 -▁express -3031 -ores -3032 -▁gave -3033 -▁Found -3034 -enty -3035 -iles -3036 -▁offe -3037 -▁consult -3038 -▁Year -3039 -▁gift -3040 -▁subject -3041 -▁Mem -3042 -AD -3043 -▁Afric -3044 -▁prices -3045 -▁successful -3046 -ties -3047 -▁positive -3048 -▁employees -3049 -arlier -3050 -▁blood -3051 -▁AN -3052 -▁race -3053 -itute -3054 -▁deliver -3055 -oul -3056 -▁join -3057 -ares -3058 -▁itself -3059 -▁King -3060 -▁shot -3061 -▁advice -3062 -▁cert -3063 -▁THE -3064 -▁eye -3065 -riend -3066 -▁hour -3067 -▁defe -3068 -▁saying -3069 -▁healthy -3070 -▁glass -3071 -▁creating -3072 -▁Sub -3073 -▁According -3074 -▁dark -3075 -ration -3076 -▁spent -3077 -▁div -3078 -▁Even -3079 -▁Why -3080 -field -3081 -▁cy -3082 -itely -3083 -ford -3084 -▁Best -3085 -▁cancer -3086 -▁Christmas -3087 -▁effective -3088 -▁serve -3089 -omen -3090 -▁sites -3091 -▁budget -3092 -▁Whe -3093 -▁Road -3094 -▁lif -3095 -▁goals -3096 -▁message -3097 -king -3098 -▁Vis -3099 -▁reve -3100 -mb -3101 -down -3102 -▁Paul -3103 -▁fair -3104 -▁India -3105 -▁average -3106 -▁Dan -3107 -▁fix -3108 -▁circ -3109 -▁Office -3110 -▁Pri -3111 -▁condu -3112 -▁East -3113 -▁reach -3114 -elling -3115 -▁Since -3116 -▁cross -3117 -aughter -3118 -▁traditional -3119 -▁extreme -3120 -▁organiz -3121 -▁director -3122 -PS -3123 -▁Hot -3124 -▁implement -3125 -Ch -3126 -▁sometimes -3127 -▁physical -3128 -▁obs -3129 -ipped -3130 -▁camer -3131 -ords -3132 -vis -3133 -▁Oh -3134 -▁opp -3135 -▁adult -3136 -▁terms -3137 -iable -3138 -▁Germ -3139 -▁plant -3140 -▁wonderful -3141 -US -3142 -rote -3143 -▁hor -3144 -▁Many -3145 -▁Rec -3146 -▁aim -3147 -▁attempt -3148 -▁limited -3149 -▁pictures -3150 -tee -3151 -▁Japan -3152 -▁See -3153 -▁Develop -3154 -▁excellent -3155 -▁dro -3156 -urning -3157 -ysis -3158 -▁mount -3159 -BC -3160 -▁emb -3161 -▁Work -3162 -imately -3163 -onse -3164 -▁brought -3165 -uth -3166 -yond -3167 -▁Ann -3168 -▁quarter -3169 -hest -3170 -▁title -3171 -▁section -3172 -ecutive -3173 -▁block -3174 -▁delivery -3175 -▁Mor -3176 -▁became -3177 -▁farm -3178 -▁arr -3179 -▁carry -3180 -▁effort -3181 -▁IN -3182 -▁kitchen -3183 -▁mention -3184 -▁developed -3185 -▁imm -3186 -inary -3187 -▁Use -3188 -iance -3189 -yright -3190 -reci -3191 -▁jud -3192 -▁fish -3193 -▁China -3194 -▁Inter -3195 -▁countries -3196 -estern -3197 -▁progress -3198 -▁necessary -3199 -▁ge -3200 -▁suppl -3201 -▁sweet -3202 -pendent -3203 -▁complex -3204 -ocks -3205 -▁baby -3206 -vest -3207 -▁felt -3208 -mitted -3209 -▁feeling -3210 -▁System -3211 -▁nation -3212 -▁promot -3213 -▁Top -3214 -▁Make -3215 -▁Dem -3216 -▁Good -3217 -hold -3218 -iced -3219 -▁birth -3220 -▁sleep -3221 -▁growing -3222 -▁impress -3223 -porate -3224 -▁Public -3225 -▁places -3226 -ocr -3227 -▁seven -3228 -▁IT -3229 -▁Flor -3230 -ffects -3231 -venue -3232 -▁Mac -3233 -▁war -3234 -▁heard -3235 -itation -3236 -gu -3237 -pite -3238 -▁weather -3239 -▁Lear -3240 -▁Open -3241 -▁region -3242 -▁Michael -3243 -haps -3244 -▁billion -3245 -▁son -3246 -itary -3247 -▁star -3248 -▁Sur -3249 -duc -3250 -▁Today -3251 -▁hotel -3252 -▁wants -3253 -Re -3254 -▁Thank -3255 -▁stick -3256 -▁college -3257 -▁construction -3258 -IL -3259 -▁bi -3260 -▁album -3261 -▁spend -3262 -▁mat -3263 -▁cold -3264 -▁medic -3265 -▁stage -3266 -▁ver -3267 -▁Port -3268 -▁Director -3269 -▁individuals -3270 -▁double -3271 -nded -3272 -▁Canada -3273 -▁Market -3274 -): -3275 -EL -3276 -aries -3277 -▁Down -3278 -▁convers -3279 -▁Russ -3280 -▁profession -3281 -ying -3282 -▁ble -3283 -▁speed -3284 -▁distrib -3285 -pects -3286 -▁exerc -3287 -rup -3288 -▁ST -3289 -aled -3290 -▁finished -3291 -fl -3292 -▁gas -3293 -istry -3294 -▁suit -3295 -ils -3296 -▁pages -3297 -▁statement -3298 -pre -3299 -ancy -3300 -▁charge -3301 -▁ing -3302 -▁spot -3303 -▁ult -3304 -▁requirements -3305 -▁finally -3306 -▁schools -3307 -▁vehicle -3308 -▁smart -3309 -▁annual -3310 -▁Windows -3311 -". -3312 -ado -3313 -wor -3314 -▁eat -3315 -useum -3316 -▁feet -3317 -▁Board -3318 -▁advant -3319 -ibly -3320 -▁blue -3321 -▁load -3322 -▁aware -3323 -unk -3324 -▁Gold -3325 -▁Research -3326 -▁straight -3327 -▁appl -3328 -arc -3329 -▁Mark -3330 -▁nearly -3331 -ato -3332 -▁Bel -3333 -▁Tom -3334 -▁tried -3335 -▁hous -3336 -▁avoid -3337 -aling -3338 -ports -3339 -▁difference -3340 -▁wrote -3341 -▁William -3342 -▁Sol -3343 -▁pattern -3344 -owl -3345 -ened -3346 -▁James -3347 -▁respond -3348 -▁challenge -3349 -▁Bre -3350 -▁dog -3351 -▁beginning -3352 -ION -3353 -▁Educ -3354 -▁About -3355 -▁helping -3356 -:|| -3357 -▁benefit -3358 -▁insurance -3359 -▁situation -3360 -iment -3361 -▁essential -3362 -▁imag -3363 -ancing -3364 -unte -3365 -▁device -3366 -ceed -3367 -▁Obama -3368 -rast -3369 -▁shop -3370 -ological -3371 -▁Care -3372 -▁Indian -3373 -▁political -3374 -box -3375 -uted -3376 -▁Time -3377 -▁loved -3378 -▁Review -3379 -ube -3380 -▁nut -3381 -▁pow -3382 -overn -3383 -▁wear -3384 -▁Apple -3385 -▁Sl -3386 -▁Mag -3387 -olute -3388 -▁Find -3389 -▁activity -3390 -▁devices -3391 -▁moving -3392 -▁Met -3393 -▁lik -3394 -▁paid -3395 -▁enh -3396 -▁Club -3397 -▁Hel -3398 -▁uses -3399 -▁eight -3400 -▁exhib -3401 -▁Court -3402 -▁turned -3403 -oms -3404 -oses -3405 -▁posted -3406 -▁towards -3407 -”. -3408 -▁nature -3409 -▁Sk -3410 -▁partner -3411 -asy -3412 -▁investment -3413 -ourney -3414 -▁appreci -3415 -▁offering -3416 -▁temper -3417 -▁contain -3418 -▁largest -3419 -ivil -3420 -▁knew -3421 -▁ahead -3422 -oves -3423 -rench -3424 -idered -3425 -▁retail -3426 -▁hus -3427 -▁eyes -3428 -▁owners -3429 -▁language -3430 -▁Ant -3431 -inger -3432 -▁expand -3433 -house -3434 -ey -3435 -rences -3436 -ios -3437 -▁rent -3438 -ned -3439 -▁cas -3440 -▁connect -3441 -▁wife -3442 -ampions -3443 -▁advert -3444 -▁Rel -3445 -▁Rich -3446 -▁reduce -3447 -▁European -3448 -▁guarant -3449 -ago -3450 -cause -3451 -▁Look -3452 -▁sports -3453 -▁correct -3454 -aly -3455 -anta -3456 -▁categ -3457 -▁client -3458 -▁states -3459 -▁consist -3460 -pri -3461 -▁maybe -3462 -▁named -3463 -▁definitely -3464 -hips -3465 -▁influ -3466 -▁entertain -3467 -erry -3468 -hens -3469 -▁accur -3470 -▁concept -3471 -osing -3472 -ounds -3473 -▁runs -3474 -▁grand -3475 -▁stress -3476 -IP -3477 -change -3478 -▁Super -3479 -▁guide -3480 -▁homes -3481 -▁Have -3482 -▁thous -3483 -last -3484 -▁jobs -3485 -▁offered -3486 -estival -3487 -▁earlier -3488 -▁immediately -3489 -▁doll -3490 -▁numbers -3491 -sych -3492 -▁conc -3493 -iers -3494 -▁decl -3495 -▁Fam -3496 -esome -3497 -▁Rob -3498 -▁rates -3499 -▁Council -3500 -azine -3501 -▁rev -3502 -▁Community -3503 -▁path -3504 -▁collabor -3505 -lying -3506 -roud -3507 -▁Cop -3508 -You -3509 -alt -3510 -orrow -3511 -▁candid -3512 -▁interact -3513 -ails -3514 -▁remain -3515 -▁II -3516 -more -3517 -▁bottom -3518 -sec -3519 -dule -3520 -▁Sum -3521 -▁Cong -3522 -▁belie -3523 -▁drink -3524 -▁pieces -3525 -▁exactly -3526 -asc -3527 -lim -3528 -▁tips -3529 -▁Micro -3530 -▁View -3531 -iation -3532 -▁overall -3533 -▁max -3534 -▁federal -3535 -▁storage -3536 -vin -3537 -icious -3538 -▁Custom -3539 -▁opening -3540 -▁demand -3541 -▁Two -3542 -place -3543 -▁surround -3544 -▁Cur -3545 -▁histor -3546 -▁Bay -3547 -orial -3548 -▁Rober -3549 -▁adjust -3550 -ulations -3551 -▁shipping -3552 -▁strateg -3553 -▁Internet -3554 -▁active -3555 -▁threat -3556 -ram -3557 -▁Win -3558 -▁looked -3559 -oma -3560 -▁ten -3561 -▁occas -3562 -▁length -3563 -inated -3564 -▁served -3565 -▁conference -3566 -ico -3567 -iny -3568 -▁IS -3569 -▁guys -3570 -▁rock -3571 -▁button -3572 -▁garden -3573 -▁Florida -3574 -▁acqu -3575 -▁Police -3576 -▁easier -3577 -▁Angel -3578 -yd -3579 -order -3580 -undred -3581 -▁Island -3582 -▁father -3583 -oly -3584 -▁bath -3585 -▁speak -3586 -▁attract -3587 -If -3588 -▁normal -3589 -▁thanks -3590 -dom -3591 -umn -3592 -▁Love -3593 -▁thank -3594 -▁bill -3595 -▁People -3596 -▁background -3597 -illa -3598 -rial -3599 -▁born -3600 -arily -3601 -▁girls -3602 -rig -3603 -▁Ev -3604 -▁Det -3605 -▁wedding -3606 -care -3607 -▁lots -3608 -▁damage -3609 -roid -3610 -▁Big -3611 -▁fat -3612 -▁pet -3613 -bl -3614 -ses -3615 -▁Ty -3616 -▁culture -3617 -▁replace -3618 -▁creative -3619 -▁internet -3620 -▁completed -3621 -▁assess -3622 -OL -3623 -▁Call -3624 -▁prec -3625 -aduate -3626 -atever -3627 -mod -3628 -que -3629 -▁Life -3630 -▁Team -3631 -▁wine -3632 -▁Company -3633 -▁husband -3634 -ij -3635 -▁coach -3636 -▁beyond -3637 -aith -3638 -▁cards -3639 -ipp -3640 -▁cash -3641 -▁Child -3642 -▁haven -3643 -▁altern -3644 -ota -3645 -▁Matt -3646 -▁guy -3647 -phone -3648 -▁depend -3649 -▁setting -3650 -leg -3651 -▁bul -3652 -▁Back -3653 -▁Show -3654 -▁miles -3655 -▁er -3656 -antly -3657 -force -3658 -▁transport -3659 -▁Management -3660 -ustain -3661 -body -3662 -ston -3663 -wise -3664 -▁emot -3665 -▁behav -3666 -▁driving -3667 -▁cream -3668 -▁response -3669 -iling -3670 -▁pred -3671 -▁estate -3672 -ously -3673 -het -3674 -▁USA -3675 -oving -3676 -isions -3677 -▁owner -3678 -▁Australia -3679 -friend -3680 -▁Pet -3681 -▁Sun -3682 -▁cho -3683 -error -3684 -▁Contact -3685 -izz -3686 -▁excited -3687 -▁selection -3688 -▁Ir -3689 -ales -3690 -anging -3691 -▁Ret -3692 -▁middle -3693 -▁efforts -3694 -▁particularly -3695 -▁Plan -3696 -▁Pal -3697 -itect -3698 -icks -3699 -▁Dri -3700 -▁helped -3701 -door -3702 -ustr -3703 -▁Lake -3704 -▁doub -3705 -▁colors -3706 -▁inform -3707 -▁Ve -3708 -aper -3709 -▁files -3710 -▁allowed -3711 -▁lines -3712 -▁existing -3713 -▁Bank -3714 -▁satis -3715 -▁patient -3716 -▁comfortable -3717 -istered -3718 -▁welcome -3719 -▁considered -3720 -▁responsible -3721 -▁clot -3722 -▁drop -3723 -▁truly -3724 -▁coffee -3725 -▁understanding -3726 -DA -3727 -▁plus -3728 -▁Govern -3729 -▁Thom -3730 -▁measure -3731 -set -3732 -▁economic -3733 -▁Yes -3734 -oming -3735 -▁frame -3736 -▁slight -3737 -▁journey -3738 -isl -3739 -▁Dec -3740 -▁indic -3741 -▁degree -3742 -▁ingred -3743 -▁himself -3744 -bon -3745 -▁purpose -3746 -▁tom -3747 -▁surv -3748 -▁changed -3749 -▁liter -3750 -▁mission -3751 -free -3752 -nown -3753 -ences -3754 -onstr -3755 -ona -3756 -▁Although -3757 -EM -3758 -▁pen -3759 -ologies -3760 -▁models -3761 -reed -3762 -▁train -3763 -▁winter -3764 -▁prot -3765 -▁stream -3766 -▁highest -3767 -ads -3768 -see -3769 -encies -3770 -▁prefer -3771 -▁seeing -3772 -▁strugg -3773 -▁evening -3774 -press -3775 -▁Take -3776 -▁artist -3777 -▁talking -3778 -OW -3779 -▁Camp -3780 -▁Phil -3781 -▁afford -3782 -▁Information -3783 -▁Str -3784 -▁sty -3785 -▁Smith -3786 -▁fashion -3787 -▁Republic -3788 -▁gun -3789 -▁disease -3790 -▁pool -3791 -▁absolute -3792 -OV -3793 -▁Sen -3794 -▁shopping -3795 -raw -3796 -oman -3797 -apter -3798 -▁River -3799 -▁Church -3800 -met -3801 -soft -3802 -▁Mart -3803 -▁lack -3804 -▁appoint -3805 -▁heavy -3806 -▁letter -3807 -rem -3808 -▁Color -3809 -▁British -3810 -▁daughter -3811 -▁fem -3812 -▁Rock -3813 -▁cast -3814 -▁brother -3815 -rey -3816 -▁Sing -3817 -▁flav -3818 -porary -3819 -▁occur -3820 -▁smooth -3821 -▁opin -3822 -▁increased -3823 -▁Jes -3824 -▁Music -3825 -▁moved -3826 -▁proud -3827 -▁couldn -3828 -▁launch -3829 -▁analysis -3830 -▁organizations -3831 -dd -3832 -▁PC -3833 -tion -3834 -▁mer -3835 -fit -3836 -▁links -3837 -gery -3838 -▁obt -3839 -▁Water -3840 -▁craft -3841 -▁church -3842 -▁compon -3843 -▁Blue -3844 -▁fill -3845 -▁rules -3846 -▁shared -3847 -▁spring -3848 -eria -3849 -uled -3850 -▁mail -3851 -▁Under -3852 -▁sched -3853 -▁Because -3854 -ronic -3855 -chan -3856 -▁Special -3857 -▁reviews -3858 -▁senior -3859 -▁hundred -3860 -IM -3861 -▁onto -3862 -▁whose -3863 -bed -3864 -▁Brown -3865 -net -3866 -▁fan -3867 -icing -3868 -▁Power -3869 -▁decor -3870 -▁secure -3871 -▁machine -3872 -imal -3873 -▁spread -3874 -▁u -3875 -▁frequ -3876 -▁score -3877 -ocolate -3878 -▁spirit -3879 -▁residents -3880 -amic -3881 -▁Hum -3882 -▁trade -3883 -▁science -3884 -vant -3885 -▁fra -3886 -▁Wood -3887 -▁appropri -3888 -▁officials -3889 -▁Sam -3890 -▁unit -3891 -▁died -3892 -hone -3893 -▁gone -3894 -▁manager -3895 -▁pressure -3896 -▁Like -3897 -▁challenges -3898 -TS -3899 -ady -3900 -▁clin -3901 -▁extend -3902 -▁instruct -3903 -▁dedicated -3904 -▁competition -3905 -▁Mount -3906 -▁Char -3907 -▁session -3908 -▁fant -3909 -▁Follow -3910 -▁happened -3911 -rian -3912 -▁Food -3913 -▁Mary -3914 -▁sort -3915 -ulated -3916 -▁initial -3917 -▁Fire -3918 -▁trou -3919 -▁Media -3920 -▁District -3921 -BA -3922 -icon -3923 -▁characters -3924 -▁basic -3925 -▁camera -3926 -▁holiday -3927 -azon -3928 -ategy -3929 -▁Enter -3930 -▁powerful -3931 -▁Institute -3932 -▁produce -3933 -▁beg -3934 -istics -3935 -▁Press -3936 -osition -3937 -▁dating -3938 -ette -3939 -asp -3940 -▁Hist -3941 -▁reasons -3942 -▁increasing -3943 -icken -3944 -▁shown -3945 -▁sugar -3946 -▁incred -3947 -▁extremely -3948 -▁rob -3949 -▁chem -3950 -▁Education -3951 -oos -3952 -▁AC -3953 -inese -3954 -▁volunte -3955 -▁disp -3956 -▁package -3957 -▁payment -3958 -RA -3959 -▁eval -3960 -▁guests -3961 -▁aren -3962 -▁snow -3963 -▁leader -3964 -▁biggest -3965 -▁TO -3966 -▁alone -3967 -▁object -3968 -▁proced -3969 -▁Sa -3970 -rowd -3971 -▁basis -3972 -▁disapp -3973 -▁supply -3974 -▁General -3975 -orney -3976 -▁Star -3977 -ifying -3978 -olic -3979 -▁laws -3980 -▁breat -3981 -▁graph -3982 -▁solid -3983 -▁forget -3984 -▁continues -3985 -LC -3986 -▁cars -3987 -▁guid -3988 -▁voice -3989 -▁experienced -3990 -▁Lou -3991 -▁mis -3992 -▁brows -3993 -rapy -3994 -▁arrest -3995 -▁passed -3996 -▁schedule -3997 -ken -3998 -omb -3999 -uing -4000 -▁egg -4001 -▁passion -4002 -▁dang -4003 -▁fear -4004 -▁guess -4005 -▁scene -4006 -esterday -4007 -BS -4008 -▁bur -4009 -▁steps -4010 -cel -4011 -▁Mal -4012 -▁beat -4013 -▁military -4014 -Sh -4015 -▁PR -4016 -▁Miss -4017 -gal -4018 -▁gra -4019 -▁names -4020 -▁approx -4021 -▁update -4022 -▁subst -4023 -▁During -4024 -▁protection -4025 -▁Att -4026 -▁Franc -4027 -▁French -4028 -annel -4029 -▁peace -4030 -▁conven -4031 -term -4032 -▁Who -4033 -▁ton -4034 -▁advantage -4035 -state -4036 -▁placed -4037 -▁Commission -4038 -▁pair -4039 -▁notice -4040 -▁strength -4041 -ero -4042 -What -4043 -incip -4044 -using -4045 -▁academ -4046 -▁Arch -4047 -▁epis -4048 -▁adding -4049 -▁waiting -4050 -▁although -4051 -ags -4052 -ideo -4053 -▁League -4054 -IV -4055 -▁Ben -4056 -clusive -4057 -▁Mot -4058 -▁reb -4059 -▁Alex -4060 -▁beauty -4061 -▁scient -4062 -ula -4063 -▁Dig -4064 -▁calls -4065 -▁relax -4066 -▁demonstr -4067 -▁regarding -4068 -amin -4069 -mark -4070 -ovel -4071 -▁income -4072 -▁covered -4073 -▁effects -4074 -ari -4075 -ixt -4076 -▁Sign -4077 -▁Online -4078 -uty -4079 -imin -4080 -▁copy -4081 -iverse -4082 -▁initi -4083 -▁experts -4084 -▁standards -4085 -▁technical -4086 -ros -4087 -okes -4088 -▁Atl -4089 -▁Vol -4090 -ading -4091 -▁manage -4092 -▁Chic -4093 -▁knows -4094 -▁winning -4095 -▁hospital -4096 -▁certainly -4097 -▁Real -4098 -▁batter -4099 -▁workers -4100 -▁connection -4101 -osh -4102 -▁compared -4103 -As -4104 -oe -4105 -▁RE -4106 -▁hom -4107 -ga -4108 -oop -4109 -▁Ins -4110 -▁Form -4111 -▁Development -4112 -▁wild -4113 -▁dinner -4114 -▁fabric -4115 -▁associated -4116 -▁experiences -4117 -▁Pay -4118 -▁doctor -4119 -▁master -4120 -▁cit -4121 -▁cru -4122 -▁wat -4123 -ograp -4124 -▁vote -4125 -▁posts -4126 -▁finding -4127 -▁Foundation -4128 -▁opened -4129 -▁Profess -4130 -▁reflect -4131 -IG -4132 -▁Carol -4133 -amm -4134 -▁audience -4135 -▁friendly -4136 -cell -4137 -unning -4138 -atically -4139 -mail -4140 -ctors -4141 -▁surface -4142 -▁den -4143 -▁Science -4144 -▁pm -4145 -▁Cap -4146 -itude -4147 -▁trail -4148 -▁artists -4149 -▁traffic -4150 -▁critical -4151 -▁communities -4152 -AA -4153 -uce -4154 -▁NY -4155 -▁Valley -4156 -works -4157 -▁remind -4158 -▁victim -4159 -▁Step -4160 -▁salt -4161 -▁followed -4162 -la -4163 -well -4164 -▁Rad -4165 -iques -4166 -▁Elect -4167 -▁football -4168 -tr -4169 -aming -4170 -▁electric -4171 -aven -4172 -▁Beach -4173 -▁facility -4174 -▁cry -4175 -gency -4176 -▁Disc -4177 -▁keeping -4178 -▁meaning -4179 -▁luck -4180 -▁pros -4181 -▁figure -4182 -▁learned -4183 -yer -4184 -ander -4185 -ulate -4186 -▁tickets -4187 -▁professionals -4188 -antic -4189 -▁laun -4190 -▁taste -4191 -▁instit -4192 -gen -4193 -▁bright -4194 -ech -4195 -arge -4196 -▁produced -4197 -▁watching -4198 -▁flex -4199 -▁catch -4200 -▁monitor -4201 -▁contains -4202 -lor -4203 -▁ter -4204 -There -4205 -ooper -4206 -▁entry -4207 -▁Project -4208 -▁Society -4209 -▁classic -4210 -▁department -4211 -edy -4212 -itar -4213 -▁diagn -4214 -▁lock -4215 -▁classes -4216 -rees -4217 -▁closed -4218 -▁starts -4219 -▁continued -4220 -▁dire -4221 -▁jump -4222 -▁awesome -4223 -▁kept -4224 -▁bought -4225 -▁listed -4226 -▁Christian -4227 -▁Wil -4228 -osure -4229 -▁Whether -4230 -▁neighbor -4231 -▁selected -4232 -▁Town -4233 -▁explore -4234 -▁testing -4235 -▁harm -4236 -▁Date -4237 -▁larger -4238 -▁videos -4239 -▁Another -4240 -▁presented -4241 -fast -4242 -▁Ber -4243 -▁ice -4244 -▁Times -4245 -▁transfer -4246 -▁thousands -4247 -▁developing -4248 -fin -4249 -▁capital -4250 -▁OF -4251 -iller -4252 -▁teaching -4253 -▁Mel -4254 -▁Nov -4255 -▁Long -4256 -▁force -4257 -▁grant -4258 -▁minute -4259 -▁talent -4260 -▁established -4261 -▁fol -4262 -▁Hill -4263 -▁desk -4264 -standing -4265 -▁England -4266 -▁AP -4267 -enses -4268 -▁announce -4269 -▁exciting -4270 -end -4271 -▁Vir -4272 -acity -4273 -▁Family -4274 -▁street -4275 -▁furn -4276 -▁facilities -4277 -▁Jim -4278 -▁brings -4279 -▁Tim -4280 -▁buying -4281 -▁records -4282 -▁articles -4283 -gn -4284 -▁sto -4285 -▁drug -4286 -▁ideal -4287 -▁library -4288 -▁requires -4289 -noon -4290 -itors -4291 -enance -4292 -▁Scott -4293 -▁micro -4294 -▁Chicago -4295 -win -4296 -rief -4297 -▁sup -4298 -▁rich -4299 -▁virt -4300 -▁novel -4301 -▁Chinese -4302 -▁sharing -4303 -▁updated -4304 -▁mo -4305 -part -4306 -sequ -4307 -▁Start -4308 -▁butter -4309 -▁driver -4310 -▁greater -4311 -riage -4312 -▁Sand -4313 -▁ship -4314 -▁crowd -4315 -▁wouldn -4316 -▁restaurant -4317 -imb -4318 -▁ir -4319 -lands -4320 -▁vision -4321 -▁Note -4322 -▁Exper -4323 -▁ingredients -4324 -ray -4325 -unately -4326 -▁List -4327 -▁poor -4328 -▁Stand -4329 -▁studies -4330 -▁Cup -4331 -overy -4332 -▁loan -4333 -▁Build -4334 -▁Grand -4335 -▁handle -4336 -▁plenty -4337 -▁resident -4338 -outs -4339 -▁bird -4340 -illage -4341 -ka -4342 -▁tree -4343 -▁economy -4344 -▁Central -4345 -▁leaving -4346 -▁serving -4347 -▁Div -4348 -▁sem -4349 -▁Support -4350 -SP -4351 -word -4352 -▁Mex -4353 -iture -4354 -▁beach -4355 -▁famous -4356 -ini -4357 -inn -4358 -▁Mil -4359 -lastname -4360 -▁manufacturer -4361 -▁faith -4362 -▁rooms -4363 -▁shall -4364 -▁recipe -4365 -▁Congress -4366 -CH -4367 -▁station -4368 -UR -4369 -▁react -4370 -▁shape -4371 -pective -4372 -▁origin -4373 -night -4374 -▁Amazon -4375 -▁injury -4376 -▁missing -4377 -reek -4378 -semb -4379 -▁Sil -4380 -▁upgr -4381 -▁Social -4382 -do -4383 -▁Pub -4384 -isher -4385 -▁motor -4386 -▁claims -4387 -▁medium -4388 -▁Bill -4389 -▁Posted -4390 -▁orders -4391 -▁maintain -4392 -rd -4393 -▁Fun -4394 -asure -4395 -▁brain -4396 -▁notes -4397 -▁views -4398 -▁Download -4399 -▁appropriate -4400 -▁boo -4401 -ishes -4402 -point -4403 -▁Offic -4404 -▁meant -4405 -▁older -4406 -▁spons -4407 -▁window -4408 -▁sustain -4409 -atab -4410 -▁Jesus -4411 -▁signed -4412 -berg -4413 -▁remove -4414 -cks -4415 -▁ended -4416 -▁changing -4417 -▁strategy -4418 -fr -4419 -cles -4420 -look -4421 -▁map -4422 -▁Union -4423 -outhern -4424 -▁happens -4425 -▁efficient -4426 -▁uns -4427 -going -4428 -▁advance -4429 -▁journal -4430 -ervation -4431 -▁plastic -4432 -▁Fore -4433 -▁stores -4434 -▁independent -4435 -▁iPhone -4436 -iest -4437 -▁useful -4438 -top -4439 -▁CD -4440 -umber -4441 -▁Organ -4442 -▁forms -4443 -▁leaves -4444 -▁Jul -4445 -craft -4446 -▁Light -4447 -▁Academ -4448 -acks -4449 -▁Award -4450 -▁advent -4451 -no -4452 -▁sand -4453 -▁shut -4454 -rehens -4455 -▁agency -4456 -▁repair -4457 -▁evidence -4458 -▁spending -4459 -▁afternoon -4460 -▁tim -4461 -apers -4462 -odes -4463 -rooms -4464 -▁throw -4465 -▁AND -4466 -▁menu -4467 -essions -4468 -▁secret -4469 -▁whatever -4470 -▁Fil -4471 -▁fee -4472 -estic -4473 -iliar -4474 -▁core -4475 -▁pray -4476 -▁sport -4477 -▁operations -4478 -▁combination -4479 -allery -4480 -▁Chris -4481 -▁Before -4482 -▁helpful -4483 -▁reality -4484 -atively -4485 -▁Where -4486 -▁multi -4487 -▁district -4488 -▁prepared -4489 -men -4490 -oyal -4491 -eless -4492 -icted -4493 -▁Week -4494 -▁cris -4495 -▁cab -4496 -ption -4497 -▁adop -4498 -▁tend -4499 -▁Democr -4500 -▁Series -4501 -▁status -4502 -▁balance -4503 -▁Mad -4504 -▁YOU -4505 -▁scen -4506 -▁estim -4507 -alls -4508 -▁flu -4509 -▁Both -4510 -▁flat -4511 -▁Author -4512 -▁joined -4513 -▁designs -4514 -▁remains -4515 -▁ID -4516 -▁Los -4517 -▁ride -4518 -▁corner -4519 -▁rank -4520 -▁eating -4521 -▁memory -4522 -Cl -4523 -mp -4524 -itz -4525 -▁Bet -4526 -▁Mont -4527 -▁caused -4528 -▁operating -4529 -▁Ma -4530 -aser -4531 -▁mist -4532 -▁George -4533 -▁discount -4534 -▁slightly -4535 -▁teachers -4536 -eed -4537 -▁IP -4538 -▁Women -4539 -▁esc -4540 -▁perhaps -4541 -▁primary -4542 -▁numerous -4543 -hem -4544 -▁funds -4545 -▁worry -4546 -▁survey -4547 -▁winner -4548 -▁enjoyed -4549 -▁showing -4550 -▁exercise -4551 -een -4552 -▁unc -4553 -▁Card -4554 -▁fourth -4555 -▁showed -4556 -▁spl -4557 -uries -4558 -▁anti -4559 -▁Francis -4560 -▁surgery -4561 -▁becoming -4562 -▁properties -4563 -pan -4564 -▁gain -4565 -▁recip -4566 -▁veget -4567 -▁Engine -4568 -▁markets -4569 -▁obvious -4570 -▁committed -4571 -▁suff -4572 -▁theme -4573 -▁focused -4574 -vere -4575 -▁plants -4576 -▁direction -4577 -ius -4578 -▁Tor -4579 -▁listen -4580 -▁managed -4581 -▁kick -4582 -iences -4583 -▁forum -4584 -▁chocolate -4585 -▁shel -4586 -▁limit -4587 -gers -4588 -lets -4589 -iency -4590 -▁legisl -4591 -aked -4592 -▁Its -4593 -▁Jun -4594 -▁busy -4595 -▁rain -4596 -issions -4597 -▁mechan -4598 -▁movement -4599 -▁encourage -4600 -▁rap -4601 -▁cloud -4602 -▁resist -4603 -▁putting -4604 -▁communication -4605 -OP -4606 -cher -4607 -▁bon -4608 -▁Their -4609 -▁raised -4610 -▁animals -4611 -▁assistance -4612 -?? -4613 -obe -4614 -oles -4615 -▁Bob -4616 -▁CEO -4617 -▁Full -4618 -▁Frank -4619 -▁lunch -4620 -▁defense -4621 -ita -4622 -▁analy -4623 -▁relig -4624 -life -4625 -rael -4626 -▁poll -4627 -▁corporate -4628 -▁practices -4629 -▁Technology -4630 -”, -4631 -itness -4632 -▁discover -4633 -▁Microsoft -4634 -", -4635 -gl -4636 -!!! -4637 -▁Mike -4638 -▁civil -4639 -▁reached -4640 -▁sources -4641 -bert -4642 -▁util -4643 -igation -4644 -vention -4645 -▁society -4646 -▁yesterday -4647 -orter -4648 -▁mill -4649 -▁chair -4650 -▁Wr -4651 -▁scr -4652 -▁youth -4653 -▁central -4654 -abilities -4655 -▁advanced -4656 -▁Ham -4657 -▁cart -4658 -▁architect -4659 -▁determine -4660 -REE -4661 -▁Fort -4662 -arrant -4663 -▁cleaning -4664 -▁vehicles -4665 -▁firstname -4666 -ena -4667 -ror -4668 -west -4669 -▁Tri -4670 -▁tea -4671 -▁dete -4672 -▁rare -4673 -▁AS -4674 -▁NOT -4675 -▁Mass -4676 -▁actual -4677 -yan -4678 -▁psych -4679 -▁Robert -4680 -▁tables -4681 -▁worksh -4682 -▁methods -4683 -▁leadership -4684 -▁Bur -4685 -▁ath -4686 -▁structure -4687 -kin -4688 -▁vs -4689 -▁pock -4690 -aturing -4691 -▁Commit -4692 -CC -4693 -MS -4694 -iled -4695 -▁Log -4696 -▁Set -4697 -▁fell -4698 -▁register -4699 -?” -4700 -▁repe -4701 -▁battle -4702 -▁format -4703 -▁becomes -4704 -▁willing -4705 -bre -4706 -ifts -4707 -▁colle -4708 -▁charges -4709 -▁funding -4710 -▁updates -4711 -▁thoughts -4712 -▁ju -4713 -▁Tre -4714 -ordin -4715 -▁toward -4716 -▁appears -4717 -▁visitors -4718 -▁fees -4719 -▁incor -4720 -▁sector -4721 -▁Copyright -4722 -▁absolutely -4723 -▁temperature -4724 -▁lose -4725 -▁locations -4726 -▁Keep -4727 -▁Next -4728 -▁colour -4729 -▁filled -4730 -▁songs -4731 -▁Network -4732 -▁Old -4733 -▁instru -4734 -levision -4735 -▁Wall -4736 -▁Trump -4737 -▁brown -4738 -▁Spring -4739 -▁century -4740 -▁extensive -4741 -▁Conference -4742 -kins -4743 -▁Land -4744 -▁Learn -4745 -▁Louis -4746 -▁asking -4747 -▁environmental -4748 -ola -4749 -ship -4750 -▁Way -4751 -▁topic -4752 -▁favour -4753 -▁transl -4754 -▁courses -4755 -▁profile -4756 -▁AL -4757 -▁Ol -4758 -while -4759 -▁Test -4760 -▁south -4761 -▁dur -4762 -▁Medic -4763 -▁Report -4764 -▁documents -4765 -▁previously -4766 -coh -4767 -▁Dou -4768 -▁Oper -4769 -▁adapt -4770 -▁north -4771 -ception -4772 -ipl -4773 -▁Plus -4774 -▁bowl -4775 -▁swim -4776 -ivered -4777 -▁guest -4778 -▁refer -4779 -▁visual -4780 -▁readers -4781 -▁anywhere -4782 -▁kid -4783 -▁registered -4784 -otton -4785 -▁Jeff -4786 -▁France -4787 -For -4788 -▁Cre -4789 -▁Lim -4790 -▁lux -4791 -▁sch -4792 -▁polic -4793 -▁charged -4794 -▁expertise -4795 -New -4796 -water -4797 -▁task -4798 -iration -4799 -▁upcoming -4800 -▁UN -4801 -▁wire -4802 -▁allowing -4803 -FL -4804 -▁Ok -4805 -▁selling -4806 -po -4807 -bour -4808 -▁bask -4809 -▁recommended -4810 -▁stre -4811 -▁Hotel -4812 -▁plays -4813 -▁Android -4814 -▁coverage -4815 -icip -4816 -▁Lat -4817 -▁fuel -4818 -▁neck -4819 -▁audio -4820 -▁sounds -4821 -▁Library -4822 -▁population -4823 -list -4824 -umin -4825 -▁Only -4826 -▁Conne -4827 -▁featured -4828 -▁Saf -4829 -▁pal -4830 -▁joint -4831 -▁Medical -4832 -▁princip -4833 -▁smaller -4834 -▁walking -4835 -▁ur -4836 -ulty -4837 -▁thr -4838 -▁Prov -4839 -▁seat -4840 -▁mental -4841 -▁establish -4842 -▁discussion -4843 -▁Jew -4844 -▁tun -4845 -▁apart -4846 -▁trial -4847 -▁parties -4848 -▁NE -4849 -istan -4850 -▁dance -4851 -ferences -4852 -IA -4853 -azz -4854 -ora -4855 -osis -4856 -▁Somet -4857 -▁Watch -4858 -igan -4859 -prise -4860 -▁Main -4861 -▁dogs -4862 -▁radio -4863 -▁despite -4864 -On -4865 -▁Lord -4866 -▁Walk -4867 -▁fold -4868 -▁truck -4869 -▁Africa -4870 -▁Virgin -4871 -▁scheduled -4872 -▁maintenance -4873 -▁Head -4874 -▁inspired -4875 -▁ON -4876 -▁diet -4877 -▁nine -4878 -▁restr -4879 -SA -4880 -▁writer -4881 -▁outdoor -4882 -▁Security -4883 -▁accommod -4884 -▁combined -4885 -▁van -4886 -ki -4887 -▁CA -4888 -▁har -4889 -▁citiz -4890 -▁scored -4891 -aks -4892 -alog -4893 -▁Western -4894 -rehensive -4895 -▁techniques -4896 -OO -4897 -▁Game -4898 -▁Admin -4899 -▁decide -4900 -▁seconds -4901 -▁Soft -4902 -▁Museum -4903 -▁values -4904 -▁removed -4905 -▁provider -4906 -▁sav -4907 -▁earth -4908 -▁raise -4909 -▁accompl -4910 -ownt -4911 -▁metal -4912 -▁stret -4913 -▁researc -4914 -eal -4915 -▁Place -4916 -▁spect -4917 -▁elements -4918 -▁purchased -4919 -▁joy -4920 -▁calc -4921 -▁purs -4922 -▁trees -4923 -▁launched -4924 -zen -4925 -▁Hy -4926 -▁Mer -4927 -▁sea -4928 -▁honest -4929 -▁movies -4930 -▁innovative -4931 -An -4932 -IF -4933 -▁panel -4934 -idering -4935 -▁counter -4936 -▁shooting -4937 -▁delicious -4938 -▁approximately -4939 -▁sitting -4940 -gment -4941 -▁killed -4942 -▁separate -4943 -▁edge -4944 -▁Video -4945 -▁Digital -4946 -▁teacher -4947 -▁relevant -4948 -ano -4949 -▁matt -4950 -▁approved -4951 -gage -4952 -▁lovely -4953 -▁parking -4954 -▁consumers -4955 -▁executive -4956 -My -4957 -nel -4958 -van -4959 -▁steel -4960 -▁Israel -4961 -▁Angeles -4962 -▁Manager -4963 -▁magazine -4964 -rs -4965 -ye -4966 -orry -4967 -▁hearing -4968 -▁concerns -4969 -bu -4970 -appy -4971 -igned -4972 -ushed -4973 -▁Charl -4974 -▁Person -4975 -pet -4976 -ellig -4977 -known -4978 -▁chat -4979 -▁conv -4980 -▁Georg -4981 -▁Peter -4982 -ensions -4983 -▁mostly -4984 -▁agreement -4985 -ears -4986 -▁eth -4987 -▁milk -4988 -▁rise -4989 -▁occasion -4990 -ups -4991 -▁Aud -4992 -▁tow -4993 -olars -4994 -▁Cook -4995 -▁Data -4996 -▁Join -4997 -isation -4998 -▁cheese -4999 -▁highlight -5000 -▁generation -5001 -VD -5002 -▁Ext -5003 -▁Ill -5004 -▁Penn -5005 -▁Word -5006 -▁Const -5007 -osit -5008 -▁mur -5009 -▁rid -5010 -▁Room -5011 -▁Thomas -5012 -▁identify -5013 -▁Gal -5014 -▁Pac -5015 -▁Centre -5016 -▁connected -5017 -▁intended -5018 -▁appearance -5019 -TV -5020 -fol -5021 -ring -5022 -orthern -5023 -▁controll -5024 -PA -5025 -ris -5026 -apes -5027 -▁sets -5028 -▁Prote -5029 -▁feels -5030 -▁waste -5031 -▁described -5032 -▁operation -5033 -▁commitment -5034 -▁Mo -5035 -▁Ver -5036 -irmed -5037 -▁truth -5038 -▁Master -5039 -▁academic -5040 -▁delivered -5041 -▁participate -5042 -cm -5043 -▁sympt -5044 -▁Through -5045 -ournament -5046 -!) -5047 -ENT -5048 -▁Men -5049 -oston -5050 -▁Lead -5051 -▁push -5052 -▁stars -5053 -▁Indust -5054 -▁Invest -5055 -▁server -5056 -▁Children -5057 -▁familiar -5058 -▁marriage -5059 -osen -5060 -▁Bas -5061 -▁nom -5062 -▁Arts -5063 -▁tough -5064 -▁enhance -5065 -▁capacity -5066 -▁relationships -5067 -UT -5068 -ycl -5069 -▁Upd -5070 -reens -5071 -▁cooking -5072 -▁promote -5073 -den -5074 -elines -5075 -▁landsc -5076 -ker -5077 -alend -5078 -nergy -5079 -▁cells -5080 -▁campus -5081 -▁editor -5082 -mond -5083 -▁mort -5084 -▁optim -5085 -▁cities -5086 -▁Journal -5087 -▁decisions -5088 -▁generally -5089 -▁Fair -5090 -▁signs -5091 -▁Access -5092 -▁wearing -5093 -▁therefore -5094 -▁introduced -5095 -arsh -5096 -berry -5097 -▁Vict -5098 -▁breast -5099 -▁accident -5100 -▁properly -5101 -▁processes -5102 -▁Er -5103 -prene -5104 -▁educational -5105 -▁Ul -5106 -▁Cam -5107 -cohol -5108 -eline -5109 -▁situ -5110 -▁majority -5111 -▁investigation -5112 -anda -5113 -inch -5114 -▁jew -5115 -▁minor -5116 -ya -5117 -burg -5118 -▁arm -5119 -ishing -5120 -▁opinion -5121 -▁detailed -5122 -▁Government -5123 -▁Dev -5124 -▁fly -5125 -▁Hand -5126 -▁Rest -5127 -reprene -5128 -▁technologies -5129 -▁teen -5130 -▁Chief -5131 -▁Earth -5132 -atabase -5133 -▁Global -5134 -▁minimum -5135 -▁category -5136 -▁presence -5137 -IR -5138 -▁Lab -5139 -▁ban -5140 -▁Live -5141 -▁label -5142 -▁calling -5143 -▁returned -5144 -▁emergency -5145 -▁expensive -5146 -▁mentioned -5147 -ef -5148 -▁Tur -5149 -▁feedback -5150 -fortunately -5151 -▁responsibility -5152 -▁Ari -5153 -▁Fund -5154 -▁Ohio -5155 -▁Wild -5156 -ression -5157 -▁Committee -5158 -▁installed -5159 -DF -5160 -▁Mur -5161 -▁ring -5162 -▁square -5163 -▁Johnson -5164 -▁foreign -5165 -▁bringing -5166 -▁hundreds -5167 -▁websites -5168 -▁Americans -5169 -▁installation -5170 -col -5171 -▁Que -5172 -▁plug -5173 -▁female -5174 -▁ourselves -5175 -rag -5176 -razy -5177 -▁Boston -5178 -▁entertainment -5179 -otten -5180 -ternal -5181 -▁invent -5182 -▁arrange -5183 -▁behavior -5184 -▁exchange -5185 -▁performed -5186 -▁episode -5187 -▁factors -5188 -▁consumer -5189 -▁advertising -5190 -ien -5191 -▁Pack -5192 -▁sizes -5193 -▁begins -5194 -▁satisf -5195 -hab -5196 -text -5197 -▁appeared -5198 -▁Di -5199 -▁Kn -5200 -aded -5201 -▁brief -5202 -▁sides -5203 -▁veter -5204 -▁Squ -5205 -▁flo -5206 -▁teach -5207 -▁units -5208 -▁studio -5209 -uts -5210 -▁Den -5211 -▁coast -5212 -ictions -5213 -emporary -5214 -▁MP -5215 -rist -5216 -▁Adv -5217 -▁Sup -5218 -▁Human -5219 -▁Federal -5220 -AY -5221 -▁elig -5222 -▁icon -5223 -▁tight -5224 -▁caught -5225 -▁transform -5226 -▁confidence -5227 -icians -5228 -▁chief -5229 -▁sauce -5230 -▁thick -5231 -ae -5232 -When -5233 -iser -5234 -▁Tour -5235 -▁fruit -5236 -▁Colorado -5237 -▁honor -5238 -▁holding -5239 -▁reserved -5240 -lock -5241 -▁Wal -5242 -▁Those -5243 -▁adults -5244 -▁topics -5245 -▁policies -5246 -▁supporting -5247 -spe -5248 -uke -5249 -▁https -5250 -▁Contin -5251 -▁ven -5252 -OC -5253 -hew -5254 -cean -5255 -▁alle -5256 -▁meat -5257 -▁ment -5258 -▁achie -5259 -▁chicken -5260 -▁windows -5261 -▁confident -5262 -▁HD -5263 -acle -5264 -▁vary -5265 -▁Price -5266 -rastructure -5267 -▁administration -5268 -▁Pan -5269 -▁motiv -5270 -▁animal -5271 -ifications -5272 -▁supported -5273 -with -5274 -▁Jud -5275 -▁cro -5276 -▁fantastic -5277 -ushing -5278 -▁mouth -5279 -▁sexual -5280 -▁seeking -5281 -SS -5282 -▁meal -5283 -▁Creat -5284 -▁alternative -5285 -arp -5286 -iat -5287 -arks -5288 -oted -5289 -▁Maybe -5290 -▁victory -5291 -ait -5292 -how -5293 -▁Bi -5294 -▁Search -5295 -▁Carolina -5296 -▁Australian -5297 -kes -5298 -ancer -5299 -▁Germany -5300 -▁components -5301 -▁importance -5302 -▁competitive -5303 -vy -5304 -▁sy -5305 -▁Prem -5306 -▁quiet -5307 -▁basket -5308 -▁edition -5309 -paper -5310 -▁tele -5311 -▁sister -5312 -▁dollars -5313 -rier -5314 -▁cheap -5315 -▁leads -5316 -▁thread -5317 -▁apparent -5318 -ste -5319 -▁Jon -5320 -▁rom -5321 -▁rub -5322 -unting -5323 -▁Canad -5324 -▁Sports -5325 -▁switch -5326 -▁guarantee -5327 -▁Academy -5328 -▁conduct -5329 -▁confirm -5330 -▁transact -5331 -▁conversation -5332 -inct -5333 -▁Lin -5334 -ighter -5335 -▁distance -5336 -▁Tit -5337 -▁Young -5338 -▁recru -5339 -▁centre -5340 -▁measures -5341 -▁worldwide -5342 -Com -5343 -▁Gar -5344 -▁Gen -5345 -▁info -5346 -▁Festival -5347 -▁Students -5348 -.| -5349 -etic -5350 -▁Bal -5351 -▁fif -5352 -▁picked -5353 -iability -5354 -▁remaining -5355 -▁photograph -5356 -weet -5357 -▁Jose -5358 -weight -5359 -▁bread -5360 -▁license -5361 -away -5362 -ucks -5363 -▁impl -5364 -▁flight -5365 -▁totally -5366 -▁Nor -5367 -▁rat -5368 -▁Meet -5369 -▁doubt -5370 -▁prison -5371 -▁unless -5372 -▁tack -5373 -▁Martin -5374 -inations -5375 -NA -5376 -atre -5377 -▁Sar -5378 -▁ang -5379 -▁vir -5380 -achel -5381 -uable -5382 -▁species -5383 -How -5384 -elly -5385 -ersey -5386 -▁restaurants -5387 -▁comprehensive -5388 -asks -5389 -▁seek -5390 -▁doors -5391 -▁contest -5392 -▁agencies -5393 -ailability -5394 -▁Champions -5395 -iano -5396 -verse -5397 -▁Quest -5398 -▁tests -5399 -▁faster -5400 -▁delight -5401 -▁maximum -5402 -▁celebrate -5403 -uzz -5404 -eries -5405 -▁league -5406 -▁clearly -5407 -▁musical -5408 -▁visiting -5409 -▁photograp -5410 -RC -5411 -TH -5412 -Our -5413 -▁Type -5414 -▁forg -5415 -itable -5416 -▁depart -5417 -▁painting -5418 -▁eventually -5419 -pass -5420 -▁Did -5421 -▁dyn -5422 -▁wel -5423 -estyle -5424 -▁noted -5425 -▁planned -5426 -▁election -5427 -▁revealed -5428 -▁considering -5429 -TC -5430 -otic -5431 -▁Inte -5432 -▁propos -5433 -▁prepare -5434 -▁depending -5435 -▁Cred -5436 -▁Using -5437 -▁Energy -5438 -▁arrived -5439 -▁housing -5440 -▁married -5441 -▁university -5442 -igr -5443 -▁Ro -5444 -usion -5445 -▁burn -5446 -▁lived -5447 -▁ticket -5448 -▁Hospital -5449 -▁bike -5450 -▁mine -5451 -▁Jackson -5452 -▁sessions -5453 -erg -5454 -▁Ce -5455 -▁inn -5456 -iminal -5457 -ixture -5458 -orough -5459 -▁scale -5460 -▁Assist -5461 -▁SP -5462 -wing -5463 -▁McC -5464 -▁ign -5465 -▁ris -5466 -ulous -5467 -▁FREE -5468 -▁apps -5469 -▁otherwise -5470 -▁discovered -5471 -▁Mid -5472 -▁Cost -5473 -▁compar -5474 -▁gather -5475 -▁officer -5476 -mes -5477 -▁Secret -5478 -▁climate -5479 -▁monthly -5480 -▁Japanese -5481 -▁chemical -5482 -▁neighborhood -5483 -▁boys -5484 -▁ends -5485 -▁liqu -5486 -▁evalu -5487 -▁turns -5488 -▁inches -5489 -▁spokes -5490 -▁struct -5491 -▁commission -5492 -▁Kore -5493 -▁weap -5494 -▁symptoms -5495 -ht -5496 -▁Bul -5497 -▁Cat -5498 -agram -5499 -▁freed -5500 -▁missed -5501 -▁cutting -5502 -▁accounts -5503 -▁internal -5504 -▁reliable -5505 -ias -5506 -▁ran -5507 -tered -5508 -▁pump -5509 -▁surf -5510 -related -5511 -▁brands -5512 -▁lights -5513 -▁seemed -5514 -▁appreciate -5515 -▁participants -5516 -otes -5517 -alian -5518 -▁Know -5519 -▁battery -5520 -▁organic -5521 -▁affordable -5522 -edia -5523 -▁hyd -5524 -▁Cert -5525 -▁corn -5526 -▁twice -5527 -▁Applic -5528 -▁Columb -5529 -▁Georgia -5530 -▁cultural -5531 -▁resource -5532 -▁featuring -5533 -hi -5534 -▁Second -5535 -▁automatically -5536 -They -5537 -ician -5538 -▁valid -5539 -▁athlet -5540 -▁paying -5541 -▁submit -5542 -▁African -5543 -▁meetings -5544 -iors -5545 -▁Code -5546 -▁Jones -5547 -▁Andrew -5548 -EE -5549 -▁emp -5550 -▁Share -5551 -▁bigger -5552 -▁regularly -5553 -); -5554 -Ex -5555 -but -5556 -▁Hard -5557 -▁Qual -5558 -▁debt -5559 -▁Middle -5560 -▁failed -5561 -▁supposed -5562 -▁Ep -5563 -▁Help -5564 -▁Steve -5565 -▁storm -5566 -▁accurate -5567 -▁possibly -5568 -GB -5569 -ua -5570 -ban -5571 -▁mel -5572 -▁pod -5573 -▁boost -5574 -▁deals -5575 -▁labor -5576 -▁volume -5577 -▁television -5578 -▁presentation -5579 -cont -5580 -▁fro -5581 -▁draft -5582 -▁fellow -5583 -▁realize -5584 -▁manufacturing -5585 -Pro -5586 -▁Ut -5587 -▁fle -5588 -▁Daniel -5589 -▁concent -5590 -▁Virginia -5591 -▁messages -5592 -?" -5593 -▁SH -5594 -ennis -5595 -idden -5596 -pected -5597 -▁fields -5598 -▁revenue -5599 -▁affected -5600 -▁recovery -5601 -EST -5602 -rupt -5603 -▁Boy -5604 -▁Blog -5605 -▁German -5606 -▁covers -5607 -▁shares -5608 -▁proposed -5609 -▁researchers -5610 -No -5611 -roy -5612 -eper -5613 -mosp -5614 -▁die -5615 -rical -5616 -▁Page -5617 -iamond -5618 -alendar -5619 -oration -5620 -▁Rights -5621 -ployment -5622 -▁returns -5623 -▁engineering -5624 -▁Lee -5625 -▁Tem -5626 -▁Farm -5627 -▁Travel -5628 -▁birthday -5629 -▁AD -5630 -case -5631 -▁Rom -5632 -▁aid -5633 -▁ages -5634 -▁Little -5635 -▁confirmed -5636 -▁instructions -5637 -▁amb -5638 -cious -5639 -▁Cast -5640 -▁Trust -5641 -▁dates -5642 -▁tells -5643 -▁answers -5644 -▁creation -5645 -▁interior -5646 -▁protected -5647 -ca -5648 -ters -5649 -▁Tech -5650 -▁breakfast -5651 -▁sad -5652 -▁wal -5653 -▁dish -5654 -▁chart -5655 -▁warrant -5656 -▁industrial -5657 -▁infrastructure -5658 -iner -5659 -▁nor -5660 -which -5661 -▁Orig -5662 -▁Games -5663 -▁Visit -5664 -▁loves -5665 -▁Mexico -5666 -▁county -5667 -▁applied -5668 -▁browser -5669 -▁employee -5670 -ario -5671 -▁nurs -5672 -▁agent -5673 -▁pregn -5674 -▁specifically -5675 -▁Opt -5676 -▁mir -5677 -▁poly -5678 -▁route -5679 -▁desire -5680 -▁issued -5681 -▁choices -5682 -▁decades -5683 -▁drivers -5684 -▁NC -5685 -▁Hen -5686 -▁hook -5687 -▁rapid -5688 -▁furniture -5689 -▁chain -5690 -▁foods -5691 -fection -5692 -▁flowers -5693 -▁reference -5694 -▁twe -5695 -▁hero -5696 -▁jack -5697 -▁affili -5698 -▁element -5699 -▁perfectly -5700 -▁WH -5701 -gend -5702 -▁Joe -5703 -erves -5704 -▁thus -5705 -lights -5706 -▁attorney -5707 -▁standing -5708 -▁exclusive -5709 -ansas -5710 -▁tail -5711 -▁plate -5712 -▁chosen -5713 -▁earned -5714 -▁supports -5715 -upp -5716 -▁CH -5717 -▁anc -5718 -▁yes -5719 -anger -5720 -odies -5721 -▁Made -5722 -▁bond -5723 -▁Broad -5724 -▁talks -5725 -▁Control -5726 -▁Francisco -5727 -▁employment -5728 -hand -5729 -rick -5730 -▁Ken -5731 -hetic -5732 -oking -5733 -▁mode -5734 -▁vent -5735 -▁Brand -5736 -▁remote -5737 -ibilities -5738 -▁Executive -5739 -anna -5740 -irms -5741 -▁Dom -5742 -▁End -5743 -ospit -5744 -▁Enjoy -5745 -▁agreed -5746 -▁purposes -5747 -▁apartment -5748 -▁incredible -5749 -Al -5750 -▁AT -5751 -▁Lo -5752 -lymp -5753 -▁Bon -5754 -▁wid -5755 -▁Expl -5756 -▁broken -5757 -▁improved -5758 -▁strategies -5759 -UN -5760 -can -5761 -▁DVD -5762 -▁nav -5763 -▁Does -5764 -▁logo -5765 -▁Store -5766 -▁Williams -5767 -▁processing -5768 -▁Hope -5769 -▁Pass -5770 -▁Sher -5771 -▁Current -5772 -▁illustr -5773 -▁hardware -5774 -▁surrounding -5775 -▁Sy -5776 -anges -5777 -▁cake -5778 -▁cute -5779 -▁whom -5780 -▁advis -5781 -▁Product -5782 -▁recorded -5783 -▁disappoint -5784 -BI -5785 -MA -5786 -▁Id -5787 -ench -5788 -hent -5789 -▁Equ -5790 -▁Haw -5791 -▁lit -5792 -▁Coast -5793 -▁quant -5794 -▁reput -5795 -▁rough -5796 -▁premium -5797 -aped -5798 -▁Mic -5799 -adium -5800 -▁golf -5801 -ampion -5802 -▁holds -5803 -▁judge -5804 -▁pleased -5805 -▁accepted -5806 -▁suitable -5807 -umes -5808 -idays -5809 -▁boat -5810 -▁Point -5811 -▁downt -5812 -▁losing -5813 -▁Instead -5814 -▁male -5815 -▁pure -5816 -▁grade -5817 -▁trouble -5818 -uous -5819 -▁rule -5820 -▁Three -5821 -▁wheel -5822 -▁administr -5823 -▁buildings -5824 -lyn -5825 -oga -5826 -uits -5827 -▁usual -5828 -▁History -5829 -▁explain -5830 -▁domestic -5831 -▁concerned -5832 -!” -5833 -xy -5834 -itage -5835 -▁telling -5836 -▁Minister -5837 -▁violence -5838 -▁candidates -5839 -gas -5840 -ums -5841 -▁moist -5842 -▁licens -5843 -▁aspects -5844 -▁Communic -5845 -▁injuries -5846 -▁favourite -5847 -tra -5848 -▁ok -5849 -what -5850 -▁Girl -5851 -person -5852 -▁moments -5853 -▁typically -5854 -otal -5855 -▁pun -5856 -▁tur -5857 -▁Party -5858 -▁error -5859 -▁causes -5860 -▁styles -5861 -▁Italian -5862 -▁awareness -5863 -▁registration -5864 -▁vit -5865 -▁arts -5866 -▁phil -5867 -▁Night -5868 -▁Print -5869 -▁Perform -5870 -rim -5871 -road -5872 -lines -5873 -▁oven -5874 -▁grown -5875 -▁enable -5876 -▁island -5877 -▁greatest -5878 -vell -5879 -▁Harr -5880 -▁rand -5881 -orable -5882 -▁abuse -5883 -▁shoes -5884 -▁forces -5885 -▁stated -5886 -fficient -5887 -▁surprise -5888 -va -5889 -▁FOR -5890 -▁Key -5891 -▁tag -5892 -▁taxes -5893 -▁photography -5894 -ERS -5895 -hors -5896 -▁jun -5897 -anish -5898 -cluding -5899 -▁closer -5900 -▁citizens -5901 -▁negative -5902 -▁influence -5903 -CA -5904 -bur -5905 -writ -5906 -▁Four -5907 -▁circum -5908 -▁actions -5909 -ria -5910 -▁Def -5911 -▁Dog -5912 -tters -5913 -ulture -5914 -▁retire -5915 -▁script -5916 -▁stopped -5917 -▁stretch -5918 -▁broadcast -5919 -▁Wi -5920 -pond -5921 -▁Drive -5922 -▁Local -5923 -▁gradu -5924 -▁resol -5925 -▁Division -5926 -▁wet -5927 -▁crew -5928 -▁powder -5929 -▁database -5930 -▁tomorrow -5931 -▁sam -5932 -astern -5933 -▁Olymp -5934 -▁leather -5935 -▁practical -5936 -ribe -5937 -▁Bra -5938 -▁Ell -5939 -▁Max -5940 -▁adm -5941 -▁argu -5942 -Un -5943 -▁serves -5944 -▁weekly -5945 -▁alleged -5946 -iami -5947 -udden -5948 -▁shock -5949 -▁Pacific -5950 -▁payments -5951 -▁functions -5952 -▁inspiration -5953 -DS -5954 -▁Gra -5955 -stone -5956 -▁acid -5957 -▁bound -5958 -▁faculty -5959 -And -5960 -yers -5961 -▁tro -5962 -alled -5963 -▁mini -5964 -▁funny -5965 -▁Awards -5966 -▁speech -5967 -▁receiving -5968 -▁authorities -5969 -ava -5970 -hus -5971 -▁Mat -5972 -merce -5973 -▁Ryan -5974 -▁sequ -5975 -▁thin -5976 -lywood -5977 -▁column -5978 -▁designer -5979 -ucle -5980 -▁hits -5981 -▁cable -5982 -forcement -5983 -▁supplies -5984 -▁Available -5985 -▁electronic -5986 -TA -5987 -ERE -5988 -▁rot -5989 -atholic -5990 -▁config -5991 -▁pepper -5992 -▁village -5993 -▁identified -5994 -▁tut -5995 -▁gear -5996 -▁Cross -5997 -▁random -5998 -poration -5999 -▁everyday -6000 -▁committee -6001 -GE -6002 -bol -6003 -oup -6004 -irty -6005 -▁Hor -6006 -▁Oil -6007 -under -6008 -profit -6009 -▁Econom -6010 -▁perman -6011 -▁recognized -6012 -ache -6013 -▁Aff -6014 -itate -6015 -never -6016 -right -6017 -▁Coll -6018 -▁Need -6019 -▁grab -6020 -▁atmosp -6021 -▁degrees -6022 -▁printed -6023 -▁convenient -6024 -▁healthcare -6025 -▁impressive -6026 -PM -6027 -mar -6028 -inet -6029 -▁crime -6030 -▁keeps -6031 -▁lessons -6032 -▁Michigan -6033 -Pl -6034 -So -6035 -rip -6036 -▁tab -6037 -▁Bell -6038 -▁Cond -6039 -isters -6040 -▁essay -6041 -▁flour -6042 -▁crisis -6043 -▁height -6044 -▁emotional -6045 -▁determined -6046 -▁Cas -6047 -▁Ref -6048 -▁Tay -6049 -▁voc -6050 -atoes -6051 -etime -6052 -▁Ariz -6053 -▁films -6054 -▁imagine -6055 -▁treated -6056 -▁Sometimes -6057 -▁dangerous -6058 -▁happening -6059 -▁Lt -6060 -▁PS -6061 -aren -6062 -phas -6063 -▁Dun -6064 -▁Try -6065 -▁Small -6066 -▁crazy -6067 -▁Comple -6068 -▁ongoing -6069 -▁champions -6070 -▁explained -6071 -iate -6072 -hered -6073 -inter -6074 -▁Jenn -6075 -▁Mean -6076 -uction -6077 -▁Santa -6078 -▁fixed -6079 -▁sheet -6080 -▁entreprene -6081 -Ar -6082 -▁Run -6083 -▁Sus -6084 -urban -6085 -▁Safety -6086 -▁dropped -6087 -▁Marketing -6088 -cue -6089 -rum -6090 -▁Fed -6091 -▁patterns -6092 -▁resolution -6093 -▁du -6094 -pret -6095 -▁Mach -6096 -▁Canadian -6097 -▁investors -6098 -LS -6099 -All -6100 -aid -6101 -eler -6102 -made -6103 -▁row -6104 -▁worse -6105 -▁Victor -6106 -▁dining -6107 -iversary -6108 -▁subscrib -6109 -▁gro -6110 -anged -6111 -arian -6112 -▁Writ -6113 -▁rear -6114 -▁Guide -6115 -▁command -6116 -▁trading -6117 -▁conducted -6118 -▁tradition -6119 -LA -6120 -mary -6121 -anche -6122 -osoph -6123 -▁Rose -6124 -▁soul -6125 -▁taught -6126 -▁arrested -6127 -▁attended -6128 -▁officers -6129 -▁appointment -6130 -▁collaboration -6131 -Bl -6132 -Con -6133 -▁GM -6134 -▁Kh -6135 -enced -6136 -▁lift -6137 -▁simpl -6138 -▁extended -6139 -lete -6140 -▁der -6141 -▁Priv -6142 -▁cock -6143 -▁grad -6144 -▁roof -6145 -▁Chair -6146 -▁hoping -6147 -▁alcohol -6148 -▁positions -6149 -▁Environment -6150 -▁successfully -6151 -ppers -6152 -oosing -6153 -▁native -6154 -▁tournament -6155 -Don -6156 -inson -6157 -▁grew -6158 -▁wash -6159 -▁depth -6160 -▁flood -6161 -▁Account -6162 -▁freedom -6163 -▁ordered -6164 -▁eligible -6165 -▁incident -6166 -▁sick -6167 -▁folks -6168 -▁Senate -6169 -▁versions -6170 -iana -6171 -▁Inf -6172 -▁kne -6173 -▁Mult -6174 -▁spin -6175 -▁Richard -6176 -ello -6177 -rate -6178 -▁obtain -6179 -▁severe -6180 -▁Sat -6181 -aints -6182 -▁Turn -6183 -▁Photo -6184 -▁cycle -6185 -▁guard -6186 -▁teeth -6187 -▁noticed -6188 -iki -6189 -▁bat -6190 -▁Area -6191 -▁Paris -6192 -▁advoc -6193 -▁belong -6194 -▁forced -6195 -▁massive -6196 -▁graduate -6197 -▁construct -6198 -Be -6199 -ala -6200 -cers -6201 -essed -6202 -racts -6203 -▁adds -6204 -▁dram -6205 -▁none -6206 -▁houses -6207 -▁improvement -6208 -hire -6209 -real -6210 -rics -6211 -▁Daily -6212 -▁trend -6213 -iveness -6214 -▁Summer -6215 -▁tested -6216 -▁failure -6217 -▁Building -6218 -▁valuable -6219 -▁innovation -6220 -tle -6221 -▁ol -6222 -▁Kent -6223 -▁Which -6224 -▁mixed -6225 -▁shots -6226 -▁yards -6227 -▁cotton -6228 -▁regional -6229 -ayer -6230 -utch -6231 -▁Ash -6232 -▁Die -6233 -rease -6234 -▁Carl -6235 -▁Clean -6236 -▁Right -6237 -▁council -6238 -Is -6239 -▁MS -6240 -▁Box -6241 -▁Rev -6242 -▁thorough -6243 -▁integrated -6244 -▁DC -6245 -▁syn -6246 -▁Size -6247 -▁tiny -6248 -hentic -6249 -▁output -6250 -za -6251 -▁ec -6252 -inem -6253 -▁tank -6254 -▁owned -6255 -▁concert -6256 -▁knowing -6257 -▁routine -6258 -▁turning -6259 -▁efficiency -6260 -erse -6261 -▁drugs -6262 -▁Avenue -6263 -▁facing -6264 -▁guitar -6265 -▁diverse -6266 -▁therapy -6267 -▁clothing -6268 -▁providers -6269 -▁MO -6270 -▁Sn -6271 -▁Ent -6272 -▁Tool -6273 -acking -6274 -▁Select -6275 -▁publish -6276 -▁reduced -6277 -▁interface -6278 -CE -6279 -▁fo -6280 -▁Hon -6281 -osite -6282 -secut -6283 -▁Asia -6284 -▁Though -6285 -▁yellow -6286 -▁follows -6287 -▁description -6288 -▁distribution -6289 -illy -6290 -▁LLC -6291 -▁ped -6292 -abled -6293 -ansion -6294 -▁Training -6295 -▁settings -6296 -▁surprised -6297 -▁effectively -6298 -▁EU -6299 -print -6300 -▁auto -6301 -▁dial -6302 -sembly -6303 -▁Miami -6304 -▁silver -6305 -▁mixture -6306 -▁contemporary -6307 -▁expectations -6308 -▁:) -6309 -abet -6310 -▁Ball -6311 -intage -6312 -▁baking -6313 -▁enthus -6314 -▁unable -6315 -▁carried -6316 -▁circumst -6317 -▁intellig -6318 -▁accessible -6319 -▁challenging -6320 -▁perspective -6321 -▁Ira -6322 -▁Low -6323 -▁Want -6324 -letter -6325 -▁bonus -6326 -▁risks -6327 -▁upper -6328 -quality -6329 -▁nearby -6330 -▁pulled -6331 -▁protein -6332 -▁stunning -6333 -▁candidate -6334 -CT -6335 -PR -6336 -▁af -6337 -iece -6338 -ATION -6339 -▁Phys -6340 -▁Italy -6341 -▁stands -6342 -ev -6343 -aze -6344 -claim -6345 -▁Lind -6346 -ington -6347 -▁Beaut -6348 -▁matters -6349 -▁tonight -6350 -▁significantly -6351 -rowse -6352 -▁Nick -6353 -▁laugh -6354 -▁Proper -6355 -▁excess -6356 -▁garlic -6357 -▁univers -6358 -▁witness -6359 -▁approval -6360 -▁medicine -6361 -▁carefully -6362 -sm -6363 -zy -6364 -▁hur -6365 -▁Shop -6366 -▁chapter -6367 -▁complic -6368 -▁joining -6369 -obs -6370 -flow -6371 -oral -6372 -▁Cir -6373 -oured -6374 -▁fulf -6375 -▁equal -6376 -▁kinds -6377 -▁awarded -6378 -▁bedroom -6379 -▁channel -6380 -▁hosting -6381 -▁guidance -6382 -▁vacation -6383 -▁adventure -6384 -▁increases -6385 -▁recording -6386 -▁availability -6387 -▁SU -6388 -▁Dub -6389 -▁Requ -6390 -▁sole -6391 -▁Never -6392 -▁Works -6393 -▁likes -6394 -▁emphas -6395 -▁festival -6396 -▁accessories -6397 -bal -6398 -zer -6399 -▁glad -6400 -▁iron -6401 -▁tall -6402 -▁Heart -6403 -▁loans -6404 -▁Spanish -6405 -UL -6406 -rete -6407 -▁ease -6408 -riends -6409 -▁filed -6410 -▁renew -6411 -clusion -6412 -▁cooper -6413 -▁Republican -6414 -▁exhibition -6415 -▁partnership -6416 -stal -6417 -▁hopes -6418 -▁Credit -6419 -▁Mobile -6420 -▁SE -6421 -▁Rub -6422 -acked -6423 -ether -6424 -folio -6425 -▁bags -6426 -nesota -6427 -orgeous -6428 -▁creates -6429 -▁speaking -6430 -▁lifestyle -6431 -HA -6432 -sen -6433 -you -6434 -▁diss -6435 -▁hang -6436 -▁vend -6437 -▁Connect -6438 -▁Student -6439 -To -6440 -▁) -6441 -▁AR -6442 -adow -6443 -▁unf -6444 -▁legs -6445 -▁occup -6446 -▁Disney -6447 -▁appeal -6448 -▁assets -6449 -▁motion -6450 -▁trends -6451 -▁clothes -6452 -▁context -6453 -▁reporting -6454 -▁replacement -6455 -FC -6456 -yth -6457 -onto -6458 -yard -6459 -agues -6460 -▁Email -6461 -▁spaces -6462 -▁entirely -6463 -▁scholars -6464 -▁constantly -6465 -!" -6466 -anny -6467 -ican -6468 -long -6469 -▁arms -6470 -orders -6471 -▁shift -6472 -▁stamp -6473 -▁forest -6474 -▁Members -6475 -▁certific -6476 -▁searching -6477 -▁sustainable -6478 -▁OS -6479 -irts -6480 -onym -6481 -rition -6482 -▁spark -6483 -▁Number -6484 -▁Taylor -6485 -▁engage -6486 -▁manner -6487 -▁conflic -6488 -▁believes -6489 -▁submitted -6490 -II -6491 -bi -6492 -▁LED -6493 -comes -6494 -eding -6495 -▁kill -6496 -▁luxury -6497 -▁Studies -6498 -▁streets -6499 -▁procedures -6500 -ml -6501 -▁pil -6502 -▁fort -6503 -▁Still -6504 -▁sudden -6505 -▁outstanding -6506 -rid -6507 -▁Rh -6508 -foot -6509 -▁odd -6510 -▁cuts -6511 -▁Field -6512 -▁goods -6513 -▁negot -6514 -▁awards -6515 -▁criminal -6516 -▁monitoring -6517 -▁originally -6518 -▁SC -6519 -▁Kim -6520 -ially -6521 -▁Russian -6522 -▁invited -6523 -▁trained -6524 -▁Southern -6525 -▁millions -6526 -▁seriously -6527 -▁performing -6528 -▁transition -6529 -erts -6530 -ikes -6531 -▁Pot -6532 -▁eleg -6533 -▁weak -6534 -▁walls -6535 -▁recycl -6536 -▁refund -6537 -▁unlike -6538 -▁Arizona -6539 -▁capture -6540 -osc -6541 -asts -6542 -emic -6543 -izer -6544 -▁Pop -6545 -▁dim -6546 -▁rac -6547 -athan -6548 -ented -6549 -▁ille -6550 -▁zone -6551 -▁factor -6552 -▁prompt -6553 -▁reward -6554 -friendly -6555 -PC -6556 -ih -6557 -pat -6558 -bing -6559 -▁mal -6560 -▁Very -6561 -▁entr -6562 -▁horse -6563 -▁quote -6564 -▁museum -6565 -▁Mountain -6566 -Le -6567 -Ph -6568 -ba -6569 -▁Ra -6570 -▁Far -6571 -▁anx -6572 -▁vul -6573 -▁Jersey -6574 -▁conver -6575 -▁relief -6576 -▁illness -6577 -▁fighting -6578 -ATE -6579 -icket -6580 -▁blow -6581 -▁remov -6582 -▁Despite -6583 -▁Seattle -6584 -▁Standard -6585 -▁interests -6586 -▁foundation -6587 -▁cm -6588 -izza -6589 -front -6590 -▁Braz -6591 -▁Kenn -6592 -▁Pract -6593 -▁Should -6594 -▁herself -6595 -▁virtual -6596 -▁younger -6597 -HS -6598 -born -6599 -elry -6600 -▁tip -6601 -▁Easy -6602 -▁Ford -6603 -▁Iraq -6604 -▁moves -6605 -▁pocket -6606 -▁involve -6607 -▁examples -6608 -ani -6609 -rell -6610 -▁rose -6611 -▁smile -6612 -▁pounds -6613 -▁wealth -6614 -▁offices -6615 -▁flexible -6616 -▁Minnesota -6617 -▁transportation -6618 -▁Fre -6619 -▁Ire -6620 -▁Fall -6621 -▁gifts -6622 -▁input -6623 -▁Senior -6624 -▁upload -6625 -▁bathroom -6626 -▁assessment -6627 -▁capabilities -6628 -▁Jr -6629 -▁Ray -6630 -▁Rod -6631 -▁Stat -6632 -▁eggs -6633 -▁hole -6634 -▁pink -6635 -▁directed -6636 -▁identity -6637 -anes -6638 -ifer -6639 -iler -6640 -uter -6641 -▁Luc -6642 -▁Sav -6643 -▁beer -6644 -▁rein -6645 -▁bottle -6646 -▁Finally -6647 -▁airport -6648 -▁founded -6649 -▁clinical -6650 -▁ultimate -6651 -RS -6652 -sey -6653 -▁Army -6654 -▁debut -6655 -aturally -6656 -▁scientific -6657 -At -6658 -▁Ha -6659 -aron -6660 -▁Ask -6661 -▁Jac -6662 -▁sac -6663 -▁Bible -6664 -▁Royal -6665 -▁worst -6666 -illiant -6667 -▁distinct -6668 -▁improving -6669 -car -6670 -ilst -6671 -quir -6672 -▁Est -6673 -▁Kat -6674 -▁Vers -6675 -▁Event -6676 -▁elimin -6677 -▁figures -6678 -▁fishing -6679 -▁forever -6680 -▁copyright -6681 -da -6682 -▁Put -6683 -▁bab -6684 -ashed -6685 -▁Supp -6686 -▁faces -6687 -▁hospit -6688 -▁Country -6689 -▁Software -6690 -▁? -6691 -▁Non -6692 -ingly -6693 -▁garage -6694 -▁Instagram -6695 -▁tie -6696 -arrow -6697 -icate -6698 -▁Come -6699 -▁Site -6700 -▁Again -6701 -▁spoke -6702 -▁rating -6703 -▁Charles -6704 -▁visited -6705 -▁residential -6706 -▁Cab -6707 -ylvan -6708 -▁Arab -6709 -▁Fact -6710 -▁hasn -6711 -▁blank -6712 -▁stone -6713 -aration -6714 -▁entered -6715 -▁objects -6716 -▁rig -6717 -▁split -6718 -▁contribute -6719 -▁Unfortunately -6720 -RI -6721 -awn -6722 -uine -6723 -▁Bed -6724 -▁Dist -6725 -season -6726 -▁liked -6727 -▁spots -6728 -▁murder -6729 -▁Atlanta -6730 -▁developers -6731 -▁implementation -6732 -eah -6733 -With -6734 -▁coc -6735 -▁san -6736 -▁sky -6737 -▁Term -6738 -▁pitc -6739 -cluded -6740 -▁Radio -6741 -▁shower -6742 -▁Looking -6743 -▁Systems -6744 -▁baseball -6745 -▁calendar -6746 -▁Professor -6747 -▁procedure -6748 -oes -6749 -▁Ms -6750 -That -6751 -▁Save -6752 -▁cups -6753 -▁vital -6754 -resents -6755 -▁Member -6756 -▁linked -6757 -▁historical -6758 -▁possibility -6759 -Se -6760 -omy -6761 -umps -6762 -▁Mom -6763 -▁Foot -6764 -▁vibr -6765 -▁pitch -6766 -▁flavor -6767 -▁liquid -6768 -▁drawing -6769 -▁fitness -6770 -▁password -6771 -▁household -6772 -▁programme -6773 -▁atmosphere -6774 -▁reputation -6775 -andy -6776 -hell -6777 -ossible -6778 -▁enroll -6779 -▁papers -6780 -▁recipes -6781 -▁attached -6782 -▁mountain -6783 -▁organized -6784 -▁LA -6785 -▁Pow -6786 -▁hall -6787 -▁soph -6788 -▁tiss -6789 -asters -6790 -▁liber -6791 -▁Having -6792 -▁critic -6793 -▁muscle -6794 -▁talked -6795 -▁Administration -6796 -LY -6797 -One -6798 -host -6799 -▁Sem -6800 -▁Van -6801 -▁empt -6802 -▁seed -6803 -Americ -6804 -▁Brazil -6805 -▁Russia -6806 -▁carbon -6807 -▁passing -6808 -▁privacy -6809 -▁seasons -6810 -▁victims -6811 -▁frequently -6812 -▁institutions -6813 -.' -6814 -MP -6815 -But -6816 -rad -6817 -▁CO -6818 -▁PA -6819 -▁Space -6820 -▁chose -6821 -▁Living -6822 -▁theory -6823 -▁Shipping -6824 -▁MA -6825 -Read -6826 -▁ads -6827 -enger -6828 -ordan -6829 -▁rail -6830 -▁tech -6831 -▁regul -6832 -▁profit -6833 -▁managing -6834 -▁circumstances -6835 -ras -6836 -adel -6837 -tain -6838 -▁Son -6839 -▁Barb -6840 -▁hurt -6841 -▁proven -6842 -▁Justice -6843 -▁historic -6844 -▁networks -6845 -▁permission -6846 -▁legislation -6847 -▁publication -6848 -phy -6849 -▁Ba -6850 -bury -6851 -▁Cru -6852 -▁Cut -6853 -rible -6854 -▁butt -6855 -▁inch -6856 -▁Image -6857 -▁Express -6858 -▁regulations -6859 -dy -6860 -neys -6861 -ucky -6862 -▁err -6863 -uling -6864 -▁counsel -6865 -ta -6866 -ura -6867 -▁BE -6868 -▁Ur -6869 -olis -6870 -▁Fac -6871 -worth -6872 -▁Prom -6873 -▁skill -6874 -unction -6875 -▁Source -6876 -▁debate -6877 -▁Further -6878 -▁exposure -6879 -ubs -6880 -▁($ -6881 -▁Mir -6882 -▁Nic -6883 -▁Tax -6884 -▁cos -6885 -▁west -6886 -▁Garden -6887 -▁tracks -6888 -▁operate -6889 -RL -6890 -nders -6891 -▁Link -6892 -▁Name -6893 -▁lets -6894 -ffered -6895 -▁breath -6896 -▁qualified -6897 -▁represents -6898 -▁Leg -6899 -▁Oak -6900 -▁Brad -6901 -▁delay -6902 -▁finds -6903 -▁Season -6904 -▁walked -6905 -▁technique -6906 -▁NAS -6907 -▁bow -6908 -▁obl -6909 -▁tou -6910 -▁Anth -6911 -uclear -6912 -▁Choose -6913 -▁saving -6914 -▁authors -6915 -▁Learning -6916 -▁contrast -6917 -ella -6918 -ione -6919 -pons -6920 -▁Ltd -6921 -▁lad -6922 -icial -6923 -▁Scot -6924 -▁Brian -6925 -▁normally -6926 -▁realized -6927 -▁authentic -6928 -zes -6929 -urse -6930 -▁Rog -6931 -eller -6932 -▁fifth -6933 -▁merch -6934 -▁sight -6935 -▁tasks -6936 -▁hosted -6937 -▁reader -6938 -▁causing -6939 -▁savings -6940 -▁downtown -6941 -▁instance -6942 -By -6943 -odd -6944 -▁OR -6945 -▁Tony -6946 -▁mold -6947 -▁casual -6948 -▁execut -6949 -igration -6950 -ographic -6951 -▁anticip -6952 -▁justice -6953 -▁promise -6954 -▁somewhere -6955 -▁Professional -6956 -▁architecture -6957 -ingu -6958 -stra -6959 -entle -6960 -▁coat -6961 -▁smell -6962 -▁templ -6963 -ultural -6964 -▁sample -6965 -▁consequ -6966 -▁portion -6967 -▁estimated -6968 -Sc -6969 -idi -6970 -▁Pict -6971 -▁trib -6972 -remony -6973 -▁Labor -6974 -▁agric -6975 -▁trick -6976 -▁coordin -6977 -▁default -6978 -▁sending -6979 -▁upgrade -6980 -▁priority -6981 -▁interpret -6982 -▁surprising -6983 -▁volunteers -6984 -ults -6985 -cknow -6986 -▁batt -6987 -▁soil -6988 -▁mainly -6989 -▁manual -6990 -▁matches -6991 -▁gorgeous -6992 -▁shoulder -6993 -▁certified -6994 -▁apparently -6995 -▁continuing -6996 -▁situations -6997 -law -6998 -▁Es -6999 -▁exec -7000 -▁warn -7001 -arters -7002 -▁Stock -7003 -▁banks -7004 -▁bench -7005 -▁facil -7006 -▁lucky -7007 -ylvania -7008 -▁Golden -7009 -▁planet -7010 -▁posting -7011 -▁immediate -7012 -▁guidelines -7013 -bel -7014 -▁PH -7015 -star -7016 -▁Buy -7017 -▁Hou -7018 -words -7019 -▁Wilson -7020 -▁blocks -7021 -▁Financial -7022 -▁discussed -7023 -owa -7024 -ulf -7025 -ulpt -7026 -▁Mix -7027 -▁Mrs -7028 -▁USB -7029 -class -7030 -▁bear -7031 -▁hate -7032 -earing -7033 -▁firms -7034 -▁shops -7035 -▁Policy -7036 -▁Spirit -7037 -▁drinks -7038 -▁scheme -7039 -▁Customer -7040 -▁Medicine -7041 -▁Lar -7042 -anned -7043 -▁fasc -7044 -ealand -7045 -▁charm -7046 -ogether -7047 -respond -7048 -▁ending -7049 -▁terror -7050 -▁attacks -7051 -▁singles -7052 -▁workshop -7053 -▁Engineering -7054 -▁FA -7055 -iger -7056 -▁Ron -7057 -uster -7058 -▁Stay -7059 -▁magn -7060 -▁Sales -7061 -▁layer -7062 -▁prove -7063 -▁teasp -7064 -▁fairly -7065 -▁vulner -7066 -▁Ireland -7067 -▁external -7068 -nam -7069 -▁Yet -7070 -▁hat -7071 -▁vice -7072 -ingers -7073 -▁aspect -7074 -▁capable -7075 -▁Catholic -7076 -▁retirement -7077 -from -7078 -icit -7079 -unes -7080 -▁Cro -7081 -inder -7082 -▁scan -7083 -bridge -7084 -▁Motor -7085 -▁Order -7086 -▁Phone -7087 -▁stuck -7088 -eration -7089 -▁loving -7090 -▁Toronto -7091 -▁closely -7092 -▁injured -7093 -▁listing -7094 -▁Memorial -7095 -▁clicking -7096 -▁programming -7097 -aping -7098 -▁bare -7099 -▁Linux -7100 -▁climb -7101 -▁saved -7102 -▁orange -7103 -▁Zealand -7104 -▁proceed -7105 -▁believed -7106 -▁listening -7107 -▁industries -7108 -▁destination -7109 -▁Cy -7110 -▁EV -7111 -rich -7112 -▁Exp -7113 -▁wra -7114 -uting -7115 -▁Conf -7116 -▁Eric -7117 -▁juice -7118 -▁casino -7119 -▁breaking -7120 -▁memories -7121 -▁collected -7122 -▁landscape -7123 -SE -7124 -lo -7125 -▁Ca -7126 -▁FL -7127 -alle -7128 -aska -7129 -▁Ram -7130 -otted -7131 -▁Band -7132 -▁Tenn -7133 -▁terr -7134 -angers -7135 -▁reform -7136 -▁strike -7137 -▁Welcome -7138 -▁doctors -7139 -▁Material -7140 -▁enjoying -7141 -▁religious -7142 -▁spiritual -7143 -▁suggested -7144 -ati -7145 -▁MD -7146 -▁OK -7147 -Tube -7148 -aste -7149 -odge -7150 -▁hell -7151 -▁Roman -7152 -▁blend -7153 -▁forth -7154 -▁meets -7155 -▁assign -7156 -▁winners -7157 -▁machines -7158 -▁alongside -7159 -▁relatively -7160 -equ -7161 -ghan -7162 -▁Fox -7163 -▁Ide -7164 -oster -7165 -cludes -7166 -▁index -7167 -faction -7168 -▁riding -7169 -▁choosing -7170 -▁pleasure -7171 -▁strategic -7172 -▁anniversary -7173 -Ad -7174 -gypt -7175 -▁Dur -7176 -▁gym -7177 -child -7178 -imize -7179 -▁Line -7180 -▁yard -7181 -▁Smart -7182 -▁Think -7183 -▁aside -7184 -▁boxes -7185 -▁newly -7186 -▁prize -7187 -▁treatments -7188 -▁celebration -7189 -▁Subsc -7190 -▁bodies -7191 -▁writers -7192 -▁requests -7193 -▁designers -7194 -▁engagement -7195 -bro -7196 -inte -7197 -amber -7198 -▁Dave -7199 -▁east -7200 -▁Davis -7201 -▁Happy -7202 -▁bunch -7203 -▁pharm -7204 -▁belief -7205 -▁covering -7206 -▁extension -7207 -▁performances -7208 -▁WW -7209 -days -7210 -▁Sky -7211 -▁arg -7212 -▁Bang -7213 -▁elev -7214 -▁Camer -7215 -▁buyers -7216 -▁Meanwhile -7217 -▁brilliant -7218 -De -7219 -ls -7220 -agon -7221 -obby -7222 -▁Dar -7223 -▁NFL -7224 -▁Sep -7225 -ormal -7226 -▁enem -7227 -ensity -7228 -giving -7229 -▁birds -7230 -▁broke -7231 -▁giant -7232 -▁proof -7233 -▁franch -7234 -▁division -7235 -nic -7236 -inos -7237 -▁Pak -7238 -ashes -7239 -osophy -7240 -▁Asian -7241 -▁Kevin -7242 -lements -7243 -▁acknow -7244 -▁symbol -7245 -▁titles -7246 -sylvania -7247 -▁packaging -7248 -▁platforms -7249 -▁instrument -7250 -▁differences -7251 -oty -7252 -▁raw -7253 -▁unw -7254 -iders -7255 -ureau -7256 -▁Adam -7257 -▁iPad -7258 -esides -7259 -▁meals -7260 -▁river -7261 -▁compat -7262 -▁enables -7263 -▁drinking -7264 -▁volunteer -7265 -’. -7266 -▁PDF -7267 -inton -7268 -▁mile -7269 -▁slic -7270 -▁solo -7271 -▁superv -7272 -▁letters -7273 -▁authority -7274 -.’ -7275 -wan -7276 -▁PL -7277 -alse -7278 -rage -7279 -wart -7280 -▁pip -7281 -▁Bush -7282 -▁Iran -7283 -lisher -7284 -parent -7285 -▁Story -7286 -▁urban -7287 -ainless -7288 -▁consistent -7289 -pes -7290 -▁Uk -7291 -▁|| -7292 -bles -7293 -wich -7294 -▁kit -7295 -ronics -7296 -▁Chall -7297 -▁Model -7298 -▁centers -7299 -▁charity -7300 -▁typical -7301 -▁explains -7302 -▁replaced -7303 -▁newspaper -7304 -▁communications -7305 -GA -7306 -OVID -7307 -▁rug -7308 -▁acts -7309 -▁lapt -7310 -▁vacc -7311 -▁vast -7312 -ateful -7313 -jection -7314 -▁infect -7315 -▁YouTube -7316 -▁mortgage -7317 -▁CN -7318 -leep -7319 -oker -7320 -▁Jay -7321 -▁stim -7322 -▁tape -7323 -▁trim -7324 -▁tooth -7325 -▁dreams -7326 -▁falling -7327 -▁handling -7328 -▁holidays -7329 -▁swimming -7330 -cons -7331 -iley -7332 -page -7333 -▁stir -7334 -▁Return -7335 -▁decade -7336 -▁domain -7337 -▁singer -7338 -▁Perhaps -7339 -▁destroy -7340 -▁dynamic -7341 -▁lighting -7342 -▁proposal -7343 -▁categories -7344 -▁encouraged -7345 -▁membership -7346 -▁personally -7347 -Fi -7348 -acious -7349 -▁Jason -7350 -▁Jordan -7351 -▁Columbia -7352 -▁forecast -7353 -▁informed -7354 -▁wireless -7355 -▁classroom -7356 -▁accomplish -7357 -▁initiative -7358 -▁suggestions -7359 -▁Po -7360 -▁mut -7361 -erman -7362 -▁Bird -7363 -▁Mill -7364 -▁Swed -7365 -▁slee -7366 -▁susp -7367 -▁Egypt -7368 -▁Staff -7369 -▁Treat -7370 -▁recre -7371 -▁solve -7372 -▁agents -7373 -▁combine -7374 -▁founder -7375 -▁percentage -7376 -▁Advis -7377 -▁Cancer -7378 -▁arrive -7379 -▁headed -7380 -▁expansion -7381 -▁sensitive -7382 -▁manufacturers -7383 -TER -7384 -uis -7385 -athy -7386 -▁Bad -7387 -▁Ess -7388 -▁magic -7389 -▁penal -7390 -▁Agency -7391 -▁Miller -7392 -▁Gallery -7393 -ounce -7394 -▁bars -7395 -▁embr -7396 -▁tied -7397 -▁Being -7398 -▁crash -7399 -▁flash -7400 -▁filter -7401 -▁Classic -7402 -▁Houston -7403 -▁shouldn -7404 -▁Remember -7405 -▁Transport -7406 -▁participating -7407 -▁ast -7408 -▁Talk -7409 -▁dust -7410 -▁Annual -7411 -▁Recent -7412 -▁slowly -7413 -▁Airport -7414 -▁Kingdom -7415 -▁pricing -7416 -▁travell -7417 -▁Northern -7418 -▁enterprise -7419 -ko -7420 -▁Josh -7421 -▁evol -7422 -▁mood -7423 -▁unus -7424 -▁facts -7425 -▁phones -7426 -▁Consult -7427 -▁ancient -7428 -▁presents -7429 -▁printing -7430 -▁Secretary -7431 -▁permanent -7432 -wis -7433 -onna -7434 -level -7435 -▁hire -7436 -amsung -7437 -rovers -7438 -▁Brook -7439 -▁venue -7440 -▁Joseph -7441 -▁gender -7442 -▁extract -7443 -▁intense -7444 -ervations -7445 -▁Pennsylvania -7446 -▁DI -7447 -..... -7448 -abeth -7449 -▁Base -7450 -▁assum -7451 -▁dealing -7452 -▁gallery -7453 -▁genuine -7454 -▁portfolio -7455 -▁enforcement -7456 -FA -7457 -esy -7458 -site -7459 -▁suc -7460 -igate -7461 -uties -7462 -▁Film -7463 -▁gall -7464 -ership -7465 -▁Level -7466 -▁roles -7467 -ologist -7468 -▁Create -7469 -▁watched -7470 -▁producing -7471 -▁IC -7472 -lers -7473 -wear -7474 -▁Dam -7475 -asted -7476 -mates -7477 -▁fest -7478 -making -7479 -▁scenes -7480 -▁constit -7481 -▁carrying -7482 -▁suffered -7483 -▁traveling -7484 -▁attractive -7485 -OD -7486 -Tr -7487 -▁Own -7488 -▁Sea -7489 -iking -7490 -oices -7491 -▁Webs -7492 -▁vari -7493 -ardens -7494 -▁Grant -7495 -ulating -7496 -▁Silver -7497 -▁border -7498 -▁assault -7499 -▁Continue -7500 -▁generate -7501 -▁assistant -7502 -▁Collection -7503 -▁guaranteed -7504 -▁recommendations -7505 -Do -7506 -axy -7507 -bar -7508 -pir -7509 -Book -7510 -▁Sym -7511 -▁Stan -7512 -▁trig -7513 -▁wins -7514 -▁Books -7515 -▁absor -7516 -▁stake -7517 -▁Studio -7518 -▁Quality -7519 -▁chances -7520 -▁Personal -7521 -▁equipped -7522 -▁Ter -7523 -Press -7524 -books -7525 -active -7526 -▁grass -7527 -▁opens -7528 -▁solar -7529 -inating -7530 -▁compens -7531 -▁heading -7532 -▁Everyone -7533 -▁diseases -7534 -▁reducing -7535 -▁Hollywood -7536 -▁languages -7537 -▁professor -7538 -▁incredibly -7539 -boy -7540 -▁rh -7541 -aine -7542 -ilty -7543 -raid -7544 -burgh -7545 -▁Fred -7546 -▁actor -7547 -▁formed -7548 -▁Eastern -7549 -▁booking -7550 -▁podcast -7551 -▁speaker -7552 -▁Experience -7553 -▁interactive -7554 -SC -7555 -Te -7556 -rm -7557 -amel -7558 -▁hel -7559 -▁anyway -7560 -▁lawyer -7561 -▁neighb -7562 -▁cookies -7563 -▁Magazine -7564 -▁Therefore -7565 -acc -7566 -ila -7567 -▁CL -7568 -▁Deb -7569 -asant -7570 -ctive -7571 -▁Bern -7572 -▁lect -7573 -▁Force -7574 -▁Henry -7575 -▁Would -7576 -▁formal -7577 -▁string -7578 -▁filling -7579 -▁Products -7580 -▁purchasing -7581 -▁connections -7582 -alo -7583 -run -7584 -▁Gi -7585 -etch -7586 -game -7587 -phia -7588 -shire -7589 -▁narr -7590 -▁alive -7591 -▁pride -7592 -graduate -7593 -▁preferred -7594 -▁Hi -7595 -ials -7596 -▁Ath -7597 -▁Hun -7598 -▁Mov -7599 -stein -7600 -▁Clin -7601 -▁Emer -7602 -▁Guard -7603 -▁Major -7604 -▁phase -7605 -▁limits -7606 -▁marked -7607 -▁writes -7608 -▁defined -7609 -▁deposit -7610 -▁visible -7611 -▁suggests -7612 -oto -7613 -swe -7614 -roke -7615 -▁Tel -7616 -▁Kids -7617 -▁seats -7618 -▁shell -7619 -▁accused -7620 -▁aggress -7621 -▁expressed -7622 -▁basketball -7623 -Fr -7624 -▁EN -7625 -onic -7626 -allas -7627 -▁bact -7628 -lessly -7629 -▁empty -7630 -▁Estate -7631 -▁hotels -7632 -▁nights -7633 -▁racing -7634 -▁Comment -7635 -▁jewelry -7636 -▁substant -7637 -▁primarily -7638 -esh -7639 -imp -7640 -▁CP -7641 -bell -7642 -▁bid -7643 -▁gay -7644 -utter -7645 -▁Past -7646 -▁aims -7647 -▁lady -7648 -▁habit -7649 -▁Father -7650 -▁Histor -7651 -▁Mother -7652 -▁Things -7653 -▁rental -7654 -▁shapes -7655 -▁weapons -7656 -itionally -7657 -▁accuracy -7658 -▁resulting -7659 -▁creativity -7660 -▁specialist -7661 -▁vegetables -7662 -AV -7663 -▁oz -7664 -ogue -7665 -▁Has -7666 -▁lie -7667 -ifies -7668 -inity -7669 -▁cycl -7670 -intend -7671 -▁Based -7672 -▁bills -7673 -limited -7674 -▁remark -7675 -▁rising -7676 -▁engaged -7677 -▁instant -7678 -▁organis -7679 -▁politics -7680 -▁Published -7681 -▁recognition -7682 -ns -7683 -hour -7684 -▁Las -7685 -inois -7686 -uters -7687 -▁Give -7688 -▁Iowa -7689 -▁Marc -7690 -▁Tele -7691 -abetes -7692 -▁Vegas -7693 -▁criteria -7694 -▁suffering -7695 -▁compliance -7696 -essee -7697 -▁rice -7698 -▁marks -7699 -adelphia -7700 -▁Officer -7701 -▁compare -7702 -▁desired -7703 -▁component -7704 -▁highlights -7705 -▁TR -7706 -uana -7707 -▁tub -7708 -oween -7709 -▁dism -7710 -▁Prime -7711 -▁brush -7712 -▁Kansas -7713 -▁dollar -7714 -▁Britain -7715 -▁crucial -7716 -▁graphic -7717 -▁recover -7718 -▁achieved -7719 -▁literally -7720 -▁interviews -7721 -jo -7722 -igs -7723 -lee -7724 -▁Ap -7725 -greg -7726 -▁Map -7727 -▁tap -7728 -▁Fast -7729 -▁HERE -7730 -▁duty -7731 -makers -7732 -▁Among -7733 -▁Steel -7734 -▁knock -7735 -▁healing -7736 -▁illegal -7737 -▁admitted -7738 -▁describe -7739 -▁entering -7740 -▁releases -7741 -▁speakers -7742 -▁Solutions -7743 -▁functional -7744 -des -7745 -▁pra -7746 -▁Roll -7747 -▁Cover -7748 -▁Kelly -7749 -athered -7750 -▁intent -7751 -▁Edition -7752 -▁massage -7753 -▁packages -7754 -▁Following -7755 -▁attending -7756 -▁obviously -7757 -li -7758 -uan -7759 -▁EX -7760 -mers -7761 -▁Meth -7762 -▁keys -7763 -▁heads -7764 -holders -7765 -▁Change -7766 -▁Orange -7767 -▁matching -7768 -▁displayed -7769 -▁recognize -7770 -▁wondering -7771 -▁correspond -7772 -isa -7773 -▁CC -7774 -▁IM -7775 -Cont -7776 -orous -7777 -▁Diego -7778 -▁dough -7779 -▁trips -7780 -▁signal -7781 -▁developer -7782 -▁exceptional -7783 -▁increasingly -7784 -%. -7785 -ja -7786 -htt -7787 -▁Ros -7788 -athon -7789 -heast -7790 -▁Dead -7791 -▁puts -7792 -▁till -7793 -▁Nation -7794 -▁alumin -7795 -▁struck -7796 -novation -7797 -▁claimed -7798 -▁farmers -7799 -▁hitting -7800 -▁whenever -7801 -▁officially -7802 -▁introduction -7803 -pson -7804 -▁Isl -7805 -found -7806 -▁Auto -7807 -▁Body -7808 -▁king -7809 -▁mand -7810 -inding -7811 -▁Table -7812 -▁Forest -7813 -▁Valent -7814 -▁narrow -7815 -▁colours -7816 -▁Attorney -7817 -▁networking -7818 -▁necessarily -7819 -▁improvements -7820 -tail -7821 -▁bug -7822 -▁clar -7823 -▁Civil -7824 -utional -7825 -▁hidden -7826 -▁Theatre -7827 -▁texture -7828 -▁checking -7829 -▁constant -7830 -▁licensed -7831 -▁Cry -7832 -▁cust -7833 -▁root -7834 -ickets -7835 -terior -7836 -▁Youth -7837 -▁loose -7838 -▁setup -7839 -▁acting -7840 -▁Chapter -7841 -▁Reading -7842 -▁occurred -7843 -▁struggling -7844 -TP -7845 -tw -7846 -AND -7847 -▁ -7848 -e -7849 -t -7850 -a -7851 -o -7852 -i -7853 -n -7854 -s -7855 -r -7856 -h -7857 -l -7858 -d -7859 -c -7860 -u -7861 -m -7862 -p -7863 -g -7864 -f -7865 -y -7866 -w -7867 -b -7868 -. -7869 -v -7870 -, -7871 -k -7872 -T -7873 -I -7874 -S -7875 -A -7876 -- -7877 -C -7878 -0 -7879 -1 -7880 -M -7881 -P -7882 -B -7883 -x -7884 -2 -7885 -W -7886 -D -7887 -R -7888 -E -7889 -H -7890 -F -7891 -L -7892 -O -7893 -N -7894 -’ -7895 -' -7896 -: -7897 -G -7898 -j -7899 -) -7900 -3 -7901 -( -7902 -z -7903 -5 -7904 -q -7905 -" -7906 -U -7907 -4 -7908 -J -7909 -9 -7910 -6 -7911 -8 -7912 -V -7913 -Y -7914 -K -7915 -7 -7916 -! -7917 -| -7918 -/ -7919 -? -7920 -“ -7921 -” -7922 -; -7923 -– -7924 -& -7925 -$ -7926 -— -7927 -Q -7928 -X -7929 -% -7930 -Z -7931 diff --git a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/requirements.txt b/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/requirements.txt deleted file mode 100644 index 0c5eedce7b..0000000000 --- a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -numpy -tqdm -torch==2.10 -huggingface-hub -kernels -setuptools -typing-extensions==4.15.0 -datasets -tiktoken -sentencepiece \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/run_cuda_binary.sh b/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/run_cuda_binary.sh deleted file mode 100644 index 473b3388e3..0000000000 --- a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/run_cuda_binary.sh +++ /dev/null @@ -1,72 +0,0 @@ -RUN_ID=pushing_run_binary_1 \ -DATA_PATH=./data/datasets/fineweb10B_sp8192 \ -TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model \ -ATTN_PROJ_TYPE=standard \ -LOGIT_HEAD_TYPE=standard \ -TVERSKY_MEMBERSHIP=sigmoid \ -TVERSKY_NUM_FEATURES=0 \ -TVERSKY_FEATURE_POOLS=0 \ -VOCAB_SIZE=8192 \ -BITNET_GROUP_SIZE=128 \ -BIGRAM_HASH=0 \ -EMBED_DIM=254 \ -TRAINING_DEPTH_RECURRENCE=0 \ -EVAL_DEPTH_RECURRENCE=0 \ -NUM_LAYERS=15 \ -MODEL_DIM=768 \ -NUM_KV_HEADS=4 \ -NUM_HEADS=8 \ -DIFF_ATTN=0 \ -MLP_MULT=4 \ -MLP_GROUPS=0 \ -MATRIX_OPTIMIZER=muon \ -ADAM_LR=0.05 \ -ADAM_WD=0.05 \ -MUON_BACKEND_STEPS=3 \ -MUON_MOMENTUM=0.95 \ -MUON_MOMENTUM_WARMUP_START=0.85 \ -MUON_MOMENTUM_WARMUP_STEPS=500 \ -MUON_WD=0.0 \ -MATRIX_LR=0.04 \ -SCALAR_LR=0.02 \ -TIED_EMBED_LR=0.02 \ -WARMDOWN_FRACTION=0.2 \ -LOGIT_SOFTCAP=10 \ -QK_GAIN_INIT=2.25 \ -ROPE_TYPE=yarn \ -YARN_MAX_LEN=2048 \ -ROPE_BASE=5000 \ -BATCH_TOKENS_START=0 \ -BATCH_SCHEDULE_FRACTION=0.33 \ -TRAIN_BATCH_TOKENS=524288 \ -SEQ_LEN_START=0 \ -SEQ_SCHEDULE_FRACTION=0.0 \ -TRAIN_SEQ_LEN=1024 \ -SMEAR=1 \ -ITERATIONS=50000 \ -WARMUP_STEPS=5 \ -MAX_WALLCLOCK_SECONDS=0 \ -VAL_LOSS_EVERY=0 \ -TRAIN_LOG_EVERY=500 \ -CHURN_LOG_EVERY=1000 \ -VAL_MAX_TOKENS=0 \ -TIE_EMBEDDINGS=1 \ -UNTIE_AT_FRACTION=0.00 \ -HEAD_LR=0.02 \ -CORR_WEIGHT_LR=0.02 \ -ACTIVATION=relu2 \ -SOFTCAP_TYPE=poly \ -MTP_HEADS=0 \ -REFINER=0 \ -REFINER_KERNEL=3 \ -SLIDING_EVAL=1 \ -SLIDING_EVAL_STRIDE=16 \ -SLIDING_BATCH_SIZE=256 \ -TEMP_SCALING=1 \ -FP_STORAGE=FP8 \ -EMA=0 \ -EMA_DECAY=0.995 \ -EMA_START_FRACTION=0.5 \ -SEED=42 \ -COMPILE_MODE=default \ -OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 train_gpt_cuda_binary.py diff --git a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/setup.sh b/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/setup.sh deleted file mode 100644 index 93f1c41fea..0000000000 --- a/records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/setup.sh +++ /dev/null @@ -1,143 +0,0 @@ -#!/bin/bash -# ------------------------------------------------------------------------------- -# Parameter Golf -- Complete Environment Setup Script -# Drop this into the project root and run: bash setup.sh -# ------------------------------------------------------------------------------- - -set -e - -echo "----------------------------------------------" -echo " Parameter Golf -- Environment Setup" -echo "----------------------------------------------" - -# ------------------------------------------------------------------------------- -# 1. Miniconda -# ------------------------------------------------------------------------------- -echo "" -echo "[1/5] Miniconda..." - -if [ -d "$HOME/miniconda3" ]; then - echo " Already installed -- skipping." -else - wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh - bash /tmp/miniconda.sh -b - rm /tmp/miniconda.sh - ~/miniconda3/bin/conda init bash - echo " Installed." -fi - -export PATH="$HOME/miniconda3/bin:$PATH" -source ~/miniconda3/etc/profile.d/conda.sh - -echo " Accepting conda TOS..." -~/miniconda3/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main -~/miniconda3/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r -echo " TOS accepted." - -# ------------------------------------------------------------------------------- -# 2. Python Environment -# ------------------------------------------------------------------------------- -echo "" -echo "[2/5] Python 3.13 environment..." - -if conda env list | grep -q "^golf "; then - echo " Environment 'golf' already exists -- skipping." -else - conda create -n golf python=3.13 -y - echo " Created." -fi - -conda activate golf -echo " Activated." - -# ------------------------------------------------------------------------------- -# 3. Requirements -# ------------------------------------------------------------------------------- -echo "" -echo "[3/5] Requirements..." - -if python3 -c "import torch, sentencepiece, numpy" 2>/dev/null; then - echo " Core packages already installed -- skipping." -else - pip install --upgrade pip -q - pip install -r requirements.txt -q - echo " Installed." -fi - -# ------------------------------------------------------------------------------- -# 4. FlashAttention-3 -# ------------------------------------------------------------------------------- -echo "" -echo "[4/5] FlashAttention-3..." - -if python3 -c "import flash_attn" 2>/dev/null || python3 -c "import flash_attn_interface" 2>/dev/null; then - echo " Already installed -- skipping." -else - # abi3 wheel -- Python 3.9+ compatible, installs in seconds, no compilation - pip install --no-cache-dir "https://download.pytorch.org/whl/cu128/flash_attn_3-3.0.0-cp39-abi3-manylinux_2_28_x86_64.whl" - echo " Installed." -fi - -# ------------------------------------------------------------------------------- -# 5. Dataset -# ------------------------------------------------------------------------------- -echo "" -echo "[5/5] FineWeb dataset (sp8192, 10 shards)..." - -echo " Downloading... ($TRAIN_COUNT/10 train shards found)" -hf download sproos/parameter-golf-tokenizers --include "datasets/fineweb10B_sp8192/*" --local-dir ./data -echo " Downloaded." - -# ------------------------------------------------------------------------------- -# Verification -# ------------------------------------------------------------------------------- -echo "" -echo "----------------------------------------------" -echo " Verification" -echo "----------------------------------------------" - -python3 - << 'EOF' -import sys -import torch -import numpy as np -import glob - -print(f"Python : {sys.version.split()[0]}") -print(f"PyTorch : {torch.__version__}") -print(f"CUDA : {torch.cuda.is_available()}") -print(f"GPUs : {torch.cuda.device_count()}") - -if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - props = torch.cuda.get_device_properties(i) - print(f" GPU {i} : {props.name} ({props.total_memory // 1024**3}GB)") - -try: - import flash_attn - print(f"FlashAttn : {flash_attn.__version__}") -except ImportError: - try: - import flash_attn_interface - print(f"FlashAttn3 : available") - except ImportError: - print(f"FlashAttn : NOT found") - -train_files = sorted(glob.glob("./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin")) -val_files = sorted(glob.glob("./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin")) -print(f"Train shards : {len(train_files)}") -print(f"Val shards : {len(val_files)}") - -if val_files: - total = sum( - int(np.fromfile(f, dtype=' tuple[bytes, int]: - bits = ((q.reshape(-1).to(torch.int8) + 1) // 2).numpy().astype(np.uint8) - n = len(bits) - pad = (8 - n % 8) % 8 - if pad: - bits = np.concatenate([bits, np.zeros(pad, dtype=np.uint8)]) - groups = bits.reshape(-1, 8) - packed = np.zeros(len(groups), dtype=np.uint8) - for i in range(8): - packed |= groups[:, i] << i - return packed.tobytes(), n - -def unpack_binary(data: bytes, n: int) -> Tensor: - packed = np.frombuffer(data, dtype=np.uint8) - bits = np.zeros((len(packed), 8), dtype=np.int8) - for i in range(8): - bits[:, i] = (packed >> i) & 1 - flat = bits.reshape(-1)[:n] - return torch.from_numpy(flat.astype(np.int8) * 2 - 1) - -# --------------------------------------------------------------------------- -# FP4 quantization (per-row absmax, 2 values packed per byte) -# --------------------------------------------------------------------------- -def quantize_to_int4(t: Tensor) -> tuple[Tensor, Tensor, list]: - t32 = t.float() - orig_shape = t32.shape - if t32.ndim < 2: - t32 = t32.unsqueeze(0) - absmax = t32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(t32 / scale), -7, 7).to(torch.int8) - flat = q.reshape(-1) - if flat.numel() % 2 != 0: - flat = F.pad(flat, (0, 1)) - low = (flat[0::2] + 8).to(torch.uint8) - high = (flat[1::2] + 8).to(torch.uint8) - return low | (high << 4), scale.half().squeeze(-1), list(orig_shape) - -def dequantize_from_int4(packed: Tensor, scale: Tensor, shape: list) -> Tensor: - low = (packed & 0x0F).to(torch.int8) - 8 - high = ((packed >> 4) & 0x0F).to(torch.int8) - 8 - flat = torch.zeros(packed.numel() * 2, dtype=torch.int8) - flat[0::2] = low - flat[1::2] = high - numel = 1 - for s in shape: - numel *= s - flat = flat[:numel].float() - if len(shape) <= 1: - return (flat * scale.float().squeeze()).reshape(shape) - return (flat.reshape(-1, shape[-1]) * scale.float().unsqueeze(-1)).reshape(shape) - -# --------------------------------------------------------------------------- -# State dict serialization (binary + fp16/fp8/fp4) -# --------------------------------------------------------------------------- -def q_sd(state_dict: dict, group_size: int = 64, fp_storage=False, binary_override_names: set | None = None) -> tuple[dict, dict]: - "Binary for large 2D weight matrices, fp16/fp8/fp4 for everything else." - quantized = {} - stats = {"binary_params": 0, "binary_bytes": 0, "fp_params": 0, "fp_bytes": 0} - for name, tensor in state_dict.items(): - if "mtp_heads" in name: - continue - t = tensor.detach().cpu().float().contiguous() - t_orig_shape = list(t.shape) - if t.ndim == 3: - t = t.reshape(t.shape[0], -1) - is_binary_candidate = ( - t.ndim == 2 and t.numel() > 65_536 - and "tok_emb" not in name and "lm_head" not in name and "embed_proj" not in name and "bigram_emb" not in name and "lm_head_correction" not in name and "lm_head_U" not in name and "lm_head_V" not in name - and "prototypes" not in name and "tversky" not in name - ) or (binary_override_names is not None and name in binary_override_names) - if is_binary_candidate: - pad = (group_size - t.shape[1] % group_size) % group_size - t_padded = F.pad(t, (0, pad)) if pad > 0 else t - t_grouped = t_padded.reshape(-1, group_size) - scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() - q = torch.where(t_grouped >= 0, - torch.ones_like(t_grouped, dtype=torch.int8), - -torch.ones_like(t_grouped, dtype=torch.int8)) - packed_bytes, n_bits = pack_binary(q) - quantized[name] = { - "type": "binary", "packed": packed_bytes, - "scale": scale.half().squeeze(-1), - "shape": list(t.shape), "padded_cols": t_padded.shape[1], - "group_size": group_size, "n_bits": n_bits, - "orig_shape": t_orig_shape, - } - stats["binary_params"] += t.numel() - stats["binary_bytes"] += len(packed_bytes) + scale.numel() * 2 - elif fp_storage == "fp4" and t.ndim == 2: - packed, scale, orig_shape = quantize_to_int4(t) - quantized[name] = {"type": "fp4", "packed": packed, "scale": scale, "shape": orig_shape} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += packed.numel() + scale.numel() * 2 - elif fp_storage and t.ndim == 2: - quantized[name] = {"type": "fp8", "data": t.to(torch.float8_e4m3fn)} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() - else: - quantized[name] = {"type": "fp16", "data": t.half()} - stats["fp_params"] += t.numel() - stats["fp_bytes"] += t.numel() * 2 - return quantized, stats - -def deq_sd(quantized: dict, target_dtype=torch.bfloat16): - "Reconstruct full-precision state dict from quantized representation." - out = {} - for name, entry in quantized.items(): - if entry["type"] == "binary": - q = unpack_binary(entry["packed"], entry["n_bits"]) - q = q.float().reshape(-1, entry["group_size"]) - scale = entry["scale"].float().unsqueeze(-1) - # No shrinkage correction needed: binary has no zeros, q.abs().mean() == 1.0 always - t = (q * scale).reshape(-1, entry["padded_cols"]) - shape = entry["shape"] - result = t[:shape[0], :shape[1]].to(target_dtype) - orig = entry.get("orig_shape") - out[name] = result.reshape(orig).contiguous() if orig and orig != shape else result.contiguous() - elif entry["type"] == "fp8": - out[name] = entry["data"].to(torch.float32).to(target_dtype).contiguous() - elif entry["type"] == "fp4": - out[name] = dequantize_from_int4(entry["packed"], entry["scale"], entry["shape"]).to(target_dtype).contiguous() - else: - out[name] = entry["data"].to(target_dtype).contiguous() - return out - -# --------------------------------------------------------------------------- -# Binary diagnostics (logged during training) -# --------------------------------------------------------------------------- -_prev_committed: dict = {} -def churn_fn(model: nn.Module, group_size: int = 64): - global _prev_committed - total = flipped = 0 - with torch.no_grad(): - for name, p in model.named_parameters(): - if p.ndim == 2 and ("weight" in name or "prototypes" in name) and p.shape[0] > 1: - w = p.detach().float().reshape(-1, group_size) - q = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w)).cpu().numpy() - if name in _prev_committed: - flipped += int(np.sum(q != _prev_committed[name])) - total += q.size - _prev_committed[name] = q - return flipped / max(total, 1) - -# --------------------------------------------------------------------------- -# Muon optimizer (Newton-Schulz orthogonalized momentum) -# --------------------------------------------------------------------------- -def ns_orth(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, wd: float = 0.0): - super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, wd=wd)) - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr, momentum = group["lr"], group["momentum"] - backend_steps, nesterov = group["backend_steps"], group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = F.rms_norm(g.float(), (g.size(-1),)).bfloat16() - g = ns_orth(g, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr:curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - wd = group.get("wd", 0.0) - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - if wd > 0: - p.mul_(1 - lr * wd) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - -# --------------------------------------------------------------------------- -# Data loading -# --------------------------------------------------------------------------- -def ld_shard(file: Path) -> Tensor: - header_bytes = 256 * np.dtype(" Tensor: - chunks = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos:self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - -class DistributedTokenLoader: - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank, self.world_size, self.device = rank, world_size, device - self.stream = TokenStream(pattern) - def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start:start + per_rank_span].pin_memory().to(self.device, non_blocking=True).to(torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - return x, y -# --------------------------------------------------------------------------- -# Model -# --------------------------------------------------------------------------- -class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): - super().__init__() - self.eps = eps - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, (x.size(-1),), eps=self.eps) - -def apply_qat_ste(w: Tensor, fp_storage: str | bool) -> Tensor: - """Applies Straight-Through Estimator (STE) for FP4 or FP8 simulated quantization.""" - if not fp_storage: - return w - if fp_storage == "fp4": - absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) - scale = absmax / 7.0 - q = torch.clamp(torch.round(w / scale), -7.0, 7.0) - w_sim = q * scale - return (w_sim - w).detach() + w - elif fp_storage is True or fp_storage == "fp8": - w_sim = w.to(torch.float8_e4m3fn).to(w.dtype) - return (w_sim - w).detach() + w - return w - -class QATLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = False, fp_storage: str | bool = False): - super().__init__(in_features, out_features, bias=bias) - self.fp_storage = fp_storage - def forward(self, x: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.linear(x, w_qat.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) - -class QATEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, fp_storage: str | bool = False): - super().__init__(num_embeddings, embedding_dim) - self.fp_storage = fp_storage - def forward(self, input: Tensor) -> Tensor: - w_qat = apply_qat_ste(self.weight, self.fp_storage) - return F.embedding(input, w_qat, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) - -class BinaryLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=False, group_size=64): - super().__init__(in_features, out_features, bias=bias) - self.group_size = group_size - def forward(self, x: Tensor) -> Tensor: - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) - w_binary = w + ((q * scale).reshape(w.shape) - w).detach() - return F.linear(x, w_binary, - self.bias.to(x.dtype) if self.bias is not None else None) - -class NormedBinaryLinear(BinaryLinear): - "Binary linear with RMSNorm on input — for output projections receiving un-normalized activations." - def forward(self, x: Tensor) -> Tensor: - return super().forward(F.rms_norm(x, (x.size(-1),))) - -class GroupedBinaryLinear(nn.Module): - "Grouped linear with binary STE. Weight stored as 2D [groups*group_out, group_in] for binary quantization compatibility." - def __init__(self, in_features, out_features, groups=4, group_size=64, normed=False): - super().__init__() - assert in_features % groups == 0 and out_features % groups == 0 - self.groups = groups - self.group_in = in_features // groups - self.group_out = out_features // groups - self.group_size = group_size - self.normed = normed - self.weight = nn.Parameter(torch.randn(groups * self.group_out, self.group_in) * 0.02) - def forward(self, x: Tensor) -> Tensor: - if self.normed: - x = F.rms_norm(x, (x.size(-1),)) - w = self.weight.bfloat16() - g = self.group_size - w_g = w.reshape(-1, g) - scale = w_g.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = torch.where(w_g >= 0, torch.ones_like(w_g), -torch.ones_like(w_g)) - w_binary = w + ((q * scale).reshape(w.shape) - w).detach() - w_grouped = w_binary.reshape(self.groups, self.group_out, self.group_in) - bsz = x.shape[:-1] - x_g = x.reshape(*bsz, self.groups, self.group_in) - out = torch.einsum('...gi,goi->...go', x_g, w_grouped) - return out.reshape(*bsz, self.groups * self.group_out) - -class TverskyProjection(nn.Module): - "Tversky similarity: S = θ·f(A∩B) - α·f(A\\B) - β·f(B\\A). Three modes." - def __init__(self, in_features: int, out_features: int, num_features: int = 16, - group_size: int = 64, use_shared_features: bool = False, - membership: str = "sigmoid"): - super().__init__() - self.group_size = group_size - self.num_features = num_features - self.membership_type = membership - self.no_features_mode = (num_features == 0) - if not self.no_features_mode and not use_shared_features: - self.features = nn.Parameter(torch.empty(num_features, in_features).uniform_(-0.02, 0.02)) - else: - self.register_parameter('features', None) - self.prototypes = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.02, 0.02)) - self.theta = nn.Parameter(torch.tensor(1.0)) - self.alpha = nn.Parameter(torch.tensor(0.5)) - self.beta = nn.Parameter(torch.tensor(0.5)) - - def _binary_ste(self, w: Tensor) -> Tensor: - w_bf16 = w.bfloat16() - g = self.group_size - w_grouped = w_bf16.reshape(-1, g) - scale = w_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8) - q = torch.where(w_grouped >= 0, torch.ones_like(w_grouped), -torch.ones_like(w_grouped)) - w_binary = w_bf16 + ((q * scale).reshape(w_bf16.shape) - w_bf16).detach() - return w_binary.reshape(w.shape) - - def _membership(self, t: Tensor) -> Tensor: - if self.membership_type == "poly": - return torch.clamp(t * 5.0 / 4.0 + 0.5, 0.0, 1.0) - elif self.membership_type == "tanh": - return (torch.tanh(t * 5.0) + 1.0) * 0.5 - else: - return torch.sigmoid(t * 5.0) - - def forward(self, x: Tensor, shared_features: Tensor | None = None) -> Tensor: - proto = self._binary_ste(self.prototypes) - if self.no_features_mode: - x_f = x @ proto.t() - p_norm = F.normalize(proto, dim=-1) - p_f = p_norm @ p_norm.t() - else: - feat = (shared_features if shared_features is not None else self.features).float() - x_f = x @ feat.t() - p_f = proto @ feat.t() - x_s = self._membership(x_f) - p_s = self._membership(p_f) - x_a = x_f * x_s - p_a = p_f * p_s - t, a, b = self.theta.abs(), self.alpha.abs(), self.beta.abs() - return t * (x_a @ p_a.t()) - a * (x_a @ (1 - p_s).t()) - b * ((1 - x_s) @ p_a.t()) - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(p in name for p in CTP)) and param.dtype != torch.float32: - param.data = param.data.float() - -class Rotary(nn.Module): - def __init__(self, dim: int, base: float = 10000.0, no_cache: bool = False, - rope_type: str = "rope", yarn_max_len: int = 4096, train_seq_len: int = 1024): - super().__init__() - self.no_cache = no_cache - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - if rope_type == "yarn": - scale = train_seq_len / yarn_max_len - freq_idx = torch.arange(0, dim, 2, dtype=torch.float32) - ramp = torch.clamp((freq_idx / dim - 0.25) / 0.75, 0.0, 1.0) - inv_freq = inv_freq / (ramp * (1.0 / scale - 1.0) + 1.0) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - def forward(self, seq_len, device, dtype): - if self.no_cache: - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - return freqs.cos()[None, :, None, :].to(dtype=dtype), freqs.sin()[None, :, None, :].to(dtype=dtype) - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, :, None, :] - self._sin_cached = freqs.sin()[None, :, None, :] - self._seq_len_cached = seq_len - return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) - -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - half = x.size(-1) // 2 - x1, x2 = x[..., :half], x[..., half:] - return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size=64, attn_proj_type="standard", tversky_num_features=16, - tversky_feature_pools=0, no_cache=False, rope_type="rope", - yarn_max_len=4096, train_seq_len=1024, tversky_membership="sigmoid", - diff_attn=False): - super().__init__() - self.num_heads, self.num_kv_heads = num_heads, num_kv_heads - self.head_dim = dim // num_heads - self.diff_attn = diff_attn - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.c_qkv = BinaryLinear(dim, self.q_size + 2 * self.kv_size, bias=False, group_size=group_size) - self.proj = NormedBinaryLinear(dim, dim, bias=False, group_size=group_size) if attn_proj_type != "tversky" else None - if self.proj is not None: - self.proj._zero_init = True - self.tversky_proj = TverskyProjection( - dim, dim, num_features=tversky_num_features, group_size=group_size, - use_shared_features=(tversky_feature_pools > 0), - membership=tversky_membership, - ) if attn_proj_type == "tversky" else None - self.shared_features = None - self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - if diff_attn: - self.diff_lambda = nn.Parameter(torch.full((num_heads,), 0.5, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base, no_cache=no_cache, - rope_type=rope_type, yarn_max_len=yarn_max_len, - train_seq_len=train_seq_len) - def forward(self, x: Tensor) -> Tensor: - bsz, seqlen, dim = x.shape - qkv_out = self.c_qkv(x) - q_out, k_out, v_out = qkv_out.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim) - k = k_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) - q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) - cos, sin = self.rotary(seqlen, x.device, q.dtype) - q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] - if self.diff_attn: - half = self.head_dim // 2 - q1, q2 = q[..., :half], q[..., half:] - k1, k2 = k[..., :half], k[..., half:] - v1, v2 = v[..., :half], v[..., half:] - y1 = flash_attn_func(q1.contiguous(), k1.contiguous(), v1.contiguous(), causal=True) - y2 = flash_attn_func(q2.contiguous(), k2.contiguous(), v2.contiguous(), causal=True) - lam = self.diff_lambda.to(dtype=y1.dtype)[None, None, :, None] - y = torch.cat([y1 - lam * y2, y1 + lam * y2], dim=-1) - else: - y = flash_attn_func( - q.contiguous(), - k.contiguous(), - v.contiguous(), - causal=True - ) - y = y.reshape(bsz, seqlen, dim) - return self.tversky_proj(y, self.shared_features) if self.tversky_proj is not None else self.proj(y) - -class MLP(nn.Module): - def __init__(self, dim, mlp_mult, group_size=64, activation="swiglu", mlp_groups=0): - super().__init__() - hidden = mlp_mult * dim - self.activation = activation - if mlp_groups > 0: - if activation == "swiglu": - self.gate_up = GroupedBinaryLinear(dim, hidden * 2, groups=mlp_groups, group_size=group_size) - else: - self.fc = GroupedBinaryLinear(dim, hidden, groups=mlp_groups, group_size=group_size) - self.proj = GroupedBinaryLinear(hidden, dim, groups=mlp_groups, group_size=group_size, normed=True) - else: - if activation == "swiglu": - self.gate_up = BinaryLinear(dim, hidden * 2, bias=False, group_size=group_size) - else: - self.fc = BinaryLinear(dim, hidden, bias=False, group_size=group_size) - self.proj = NormedBinaryLinear(hidden, dim, bias=False, group_size=group_size) - self.proj._zero_init = True - def forward(self, x: Tensor) -> Tensor: - if self.activation == "swiglu": - gu = self.gate_up(x) - gate, up = gu.chunk(2, dim=-1) - return self.proj(F.silu(gate) * up) - elif self.activation == "relu": - return self.proj(torch.relu(self.fc(x))) - elif self.activation == "leaky_relu": - return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.01)) - else: # relu2 - return self.proj(torch.relu(self.fc(x)).square()) - -class SmearModule(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - cumsum = x.cumsum(dim=1) - counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype).view(1, -1, 1) - smeared = cumsum / counts - gate = torch.tanh(self.gate.to(dtype=x.dtype)) - return x + gate * (smeared - x) - -class CausalConvRefiner(nn.Module): - "Causal Conv1d that refines hidden states using local n-gram context." - def __init__(self, dim: int, kernel_size: int = 3): - super().__init__() - self.kernel_size = kernel_size - self.conv = nn.Conv1d(dim, dim, kernel_size, padding=0, bias=False) - self.gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - def forward(self, x: Tensor) -> Tensor: - h = x.permute(0, 2, 1) - h = F.pad(h, (self.kernel_size - 1, 0)) - h = self.conv(h) - h = h.permute(0, 2, 1) - return x + torch.tanh(self.gate.to(dtype=x.dtype)) * F.rms_norm(h, (h.size(-1),)) - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - rope_base: float, qk_gain_init: float, group_size: int=64, - activation: str="swiglu", attn_proj_type: str="standard", - tversky_num_features: int=16, tversky_feature_pools: int=0, no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn: bool=False, mlp_groups: int=0): - super().__init__() - self.attn_norm = RMSNorm() - self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, - group_size, attn_proj_type, tversky_num_features, - tversky_feature_pools, no_cache, rope_type, yarn_max_len, - train_seq_len, tversky_membership, diff_attn) - self.mlp = MLP(dim, mlp_mult, group_size, activation, mlp_groups) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - self.smear = SmearModule(dim) if smear else None - def forward(self, x: Tensor, x0: Tensor) -> Tensor: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0] * x + mix[1] * x0 - n = self.attn_norm(x) - x = x + self.attn_scale.to(dtype=x.dtype) * self.attn(n) - x = x + self.mlp_scale.to(dtype=x.dtype) * self.mlp(self.mlp_norm(x)) - if self.smear is not None: - x = self.smear(x) - return x - -class GPT(nn.Module): - def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, - tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, - group_size: int = 64, activation: str = "swiglu", mtp_heads_count: int = 0, - embed_dim: int = 0, attn_proj_type: str = "standard", logit_head_type: str = "standard", - tversky_num_features: int = 16, tversky_feature_pools: int = 0, - training_depth_recurrence: int=1, fp_storage=False, bigram_hash: bool=False, - softcap_type: str="poly", no_cache: bool=False, - smear: bool=False, rope_type: str="rope", yarn_max_len: int=4096, - train_seq_len: int=1024, tversky_membership: str="sigmoid", - diff_attn=False, mlp_groups=0, refiner=False, refiner_kernel=3): - super().__init__() - self.training_depth_recurrence = training_depth_recurrence - self.fp_storage = fp_storage - self.tie_embeddings = tie_embeddings - self.logit_softcap = logit_softcap - self.softcap_type = softcap_type - self.embed_dim = embed_dim if embed_dim > 0 else model_dim - self.tok_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) - self.bigram_emb = QATEmbedding(vocab_size, self.embed_dim, fp_storage=fp_storage) if bigram_hash else None - if self.bigram_emb is not None: - nn.init.zeros_(self.bigram_emb.weight) - self.lm_head_correction = nn.Parameter( - torch.zeros(vocab_size, self.embed_dim)) if tie_embeddings == 2 else None - self.embed_proj = QATLinear(self.embed_dim, model_dim, bias=False, fp_storage=fp_storage) if self.embed_dim != model_dim else None - self.embed_proj_rev = QATLinear(model_dim, self.embed_dim, bias=False, fp_storage=fp_storage) if ( - self.embed_dim != model_dim and logit_head_type != "tversky") else None - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - # Shared Tversky feature pools (if enabled and num_features > 0) - if attn_proj_type == "tversky" and tversky_feature_pools > 0 and tversky_num_features > 0: - self.tversky_feature_pools_list = nn.ParameterList([ - nn.Parameter(torch.empty(tversky_num_features, model_dim).uniform_(-0.02, 0.02)) - for _ in range(tversky_feature_pools) - ]) - else: - self.tversky_feature_pools_list = None - self.blocks = nn.ModuleList([ - Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, - group_size, activation, attn_proj_type, tversky_num_features, tversky_feature_pools, - no_cache, smear, rope_type, yarn_max_len, train_seq_len, tversky_membership, - diff_attn, mlp_groups) - for _ in range(num_layers) - ]) - # Inject shared feature pool references into attention layers - if self.tversky_feature_pools_list is not None: - for i, block in enumerate(self.blocks): - pool_idx = (i * tversky_feature_pools) // num_layers - block.attn.shared_features = self.tversky_feature_pools_list[pool_idx] - self.final_norm = RMSNorm() - self.refiner = CausalConvRefiner(model_dim, kernel_size=refiner_kernel) if refiner else None - self.mtp_heads = nn.ModuleList([ - nn.Linear(model_dim, vocab_size, bias=False) for _ in range(mtp_heads_count) - ]) - for h in self.mtp_heads: - nn.init.zeros_(h.weight) - self.logit_head_type = logit_head_type - if logit_head_type == "tversky" and tversky_num_features == 0 and vocab_size > 1024: - raise ValueError( - f"Tversky logit head with no-features mode creates O(V^2) = {vocab_size}x{vocab_size} " - f"matrix per forward pass. Use tversky_num_features > 0 or a smaller vocab." - ) - self.tversky_head = TverskyProjection( - model_dim, vocab_size, num_features=tversky_num_features, - membership=tversky_membership, - ) if logit_head_type == "tversky" else None - self.lm_head = QATLinear(model_dim, vocab_size, bias=False, fp_storage=fp_storage) - self.lm_head._zero_init = True - if self.lm_head is not None and (tie_embeddings or logit_head_type == "tversky"): - self.lm_head.weight.requires_grad_(False) - self.vocab_bias = nn.Parameter(torch.zeros(vocab_size, dtype=torch.float32)) - self._init_weights(tied_embed_init_std) - def _init_weights(self, tied_embed_init_std: float) -> None: - if self.tie_embeddings: - nn.init.normal_(self.tok_emb.weight, mean=0.0, std=tied_embed_init_std) - for module in self.modules(): - if isinstance(module, BinaryLinear) and not getattr(module, "_zero_init", False): - nn.init.normal_(module.weight, mean=0.0, std=0.02) - elif isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - def _compute_logits(self, x: Tensor) -> Tensor: - if self.tversky_head is not None: - logits_raw = self.tversky_head(x) - elif self.tie_embeddings: - if self.embed_proj_rev is not None: - proj = self.embed_proj_rev(x) - else: - proj = x - weight = self.tok_emb.weight - if self.lm_head_correction is not None: - weight = weight + self.lm_head_correction - logits_raw = F.linear(proj, weight.to(x.dtype)) - else: - logits_raw = self.lm_head(x) - return logits_raw + self.vocab_bias.to(x.dtype) - def _softcap(self, logits: Tensor) -> Tensor: - s = self.logit_softcap - if self.softcap_type == "tanh": - return s * torch.tanh(logits / s) - x_sc = torch.clamp(logits / s, -2.0, 2.0) - x2 = x_sc * x_sc - return s * torch.clamp(x_sc * (1.0 - x2 / 3.0 + x2 * x2 / 15.0), -1.0, 1.0) - def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", temperature: float = 1.0) -> Tensor: - x = self.tok_emb(input_ids).float() - if self.bigram_emb is not None: - prev = F.pad(input_ids[:, :-1], (1, 0), value=0) - x = x + self.bigram_emb(prev).float() - if self.embed_proj is not None: - x = self.embed_proj(x) - x = F.rms_norm(x, (x.size(-1),)) - x0 = x - # U-Net style encoder/decoder with skip connections - skips = [] - for i in range(self.num_encoder_layers): - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype) * skips.pop() - for _ in range(max(1, self.training_depth_recurrence)): - x = self.blocks[bi](x, x0) - x_normed = self.final_norm(x) - if self.refiner is not None: - x_normed = self.refiner(x_normed) - # Standard training/eval path - x_flat = x_normed.reshape(-1, x_normed.size(-1)) - targets = target_ids.reshape(-1) - logits = self._softcap(self._compute_logits(x_flat)) - if reduction == "none": - return F.cross_entropy(logits.float(), targets, reduction="none").reshape(input_ids.shape) - # Fused CE + Z-loss: single logsumexp computation - logits_f = logits.float() - lse = torch.logsumexp(logits_f, dim=-1) - target_logits = logits_f.gather(1, targets.unsqueeze(1)).squeeze(1) - main_loss = (lse - target_logits).mean() + 1e-4 * (lse ** 2).mean() - # Multi-token prediction auxiliary loss (training only) - if self.training and len(self.mtp_heads) > 0: - mtp_loss = torch.zeros((), device=main_loss.device) - for k, head in enumerate(self.mtp_heads): - shift = k + 2 - if target_ids.shape[1] > shift: - mtp_tgt = target_ids[:, shift:].reshape(-1) - mtp_in = x_normed[:, :target_ids.shape[1] - shift, :].reshape(-1, x_normed.shape[-1]) - mtp_loss = mtp_loss + F.cross_entropy(head(mtp_in).float(), mtp_tgt, reduction="mean") - main_loss = main_loss + 0.1 * mtp_loss / len(self.mtp_heads) - return main_loss - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- -def build_luts(sp, vocab_size: int, device: torch.device): - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("\u2581"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - -def ld_val(pattern, seq_len, max_tok=int(os.environ.get("VAL_MAX_TOKENS", 500000))): - files = sorted(glob.glob(pattern)) - assert files, f"No files: {pattern}" - tok = torch.cat([ld_shard(Path(p)) for p in files]).contiguous() - if max_tok > 0: tok = tok[:max_tok + 1] - u = ((tok.numel() - 1) // seq_len) * seq_len - return tok[:u + 1] - -def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, temperature: float = 1.0): - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - local_batch_seqs = max(1, local_batch_tokens // args.train_seq_len) - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_start in range(seq_start, seq_end, local_batch_seqs): - batch_end = min(batch_start + local_batch_seqs, seq_end) - raw_start = batch_start * args.train_seq_len - raw_end = batch_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) - x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch_loss = model(x, y, temperature=temperature).detach() - n = float(y.numel()) - loss_sum += batch_loss.to(torch.float64) * n - token_count += n - prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) - tok_bytes = base_bytes_lut[tgt_ids].to(torch.int16) - tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -def eval_val_sliding(args, model, rank, world_size, device, grad_accum_steps, val_tokens, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - stride: int = 64, temperature: float = 1.0): - seq_len = args.train_seq_len - batch_size = args.sliding_batch_size - total_tokens = val_tokens.numel() - 1 - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - all_starts = list(range(0, total_tokens - seq_len, stride)) - my_starts = all_starts[rank::world_size] - model.eval() - with torch.inference_mode(): - for i in range(0, len(my_starts), batch_size): - batch_starts = my_starts[i:i + batch_size] - starts_t = torch.tensor(batch_starts, dtype=torch.int64) - offsets = torch.arange(seq_len + 1, dtype=torch.int64) - indices = starts_t.unsqueeze(1) + offsets.unsqueeze(0) - local_batch = val_tokens[indices].to(device=device, dtype=torch.int64, non_blocking=True) - x = local_batch[:, :-1] - y = local_batch[:, 1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - per_token_loss = model(x, y, reduction="none", temperature=temperature).detach() - for b, start in enumerate(batch_starts): - score_from = 0 if start == 0 else seq_len - stride - scored = per_token_loss[b, score_from:] - sx, sy = x[b, score_from:], y[b, score_from:] - loss_sum += scored.to(torch.float64).sum() - token_count += scored.numel() - tok_bytes = base_bytes_lut[sy].to(torch.int16) - tok_bytes += (has_leading_space_lut[sy] & ~is_boundary_token_lut[sx]).to(torch.int16) - byte_count += tok_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): - for t in (loss_sum, token_count, byte_count): - dist.all_reduce(t, op=dist.ReduceOp.SUM) - val_loss = loss_sum / token_count - bpb = (val_loss.item() / math.log(2.0)) * (token_count.item() / byte_count.item()) - model.train() - return float(val_loss.item()), float(bpb) - -# --------------------------------------------------------------------------- -# Temperature scaling -# --------------------------------------------------------------------------- -def find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut): - best_t, best_loss = 1.0, float("inf") - for t in [0.90, 0.95, 1.00, 1.05, 1.10]: - loss, _ = eval_val(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, temperature=t) - if loss < best_loss: - best_loss = loss - best_t = t - return best_t - -# --------------------------------------------------------------------------- -# Training -# --------------------------------------------------------------------------- -def main() -> None: - args = Hyperparameters() - code = Path(__file__).read_text(encoding="utf-8") - if args.matrix_optimizer != "adamw": - global ns_orth - ns_orth = torch.compile(ns_orth) - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - grad_accum_steps = max(1, 8 // world_size) - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - os.makedirs("logs/cuda/", exist_ok=True) - logfile = f"logs/cuda/{args.run_id}.txt" if master_process else None - if master_process: - print(logfile) - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Python {sys.version}", console=False) - log0(f"PyTorch {torch.__version__}", console=False) - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - val_tokens = ld_val(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_luts( - sp, args.vocab_size, device) - - # --- Model --- - base_model = GPT( - vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, - num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, - group_size=args.bitnet_group_size, activation=args.activation_type, mtp_heads_count=args.mtp_heads_count, - embed_dim=args.embed_dim, attn_proj_type=args.attn_proj_type, logit_head_type=args.logit_head_type, - tversky_num_features=args.tversky_num_features, tversky_feature_pools=args.tversky_feature_pools, - training_depth_recurrence=args.training_depth_recurrence, fp_storage=args.fp_storage, - bigram_hash=args.bigram_hash, softcap_type=args.softcap_type, no_cache=(args.compile_mode == "reduce-overhead"), - smear=args.smear, rope_type=args.rope_type, yarn_max_len=args.yarn_max_len, train_seq_len=args.train_seq_len, - tversky_membership=args.tversky_membership, diff_attn=args.diff_attn, - refiner=args.refiner, refiner_kernel=args.refiner_kernel, mlp_groups=args.mlp_groups, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, nn.Linear): - module.float() - restore_low_dim_params_to_fp32(base_model) - if base_model.lm_head is not None and (args.tie_embeddings or args.logit_head_type == "tversky"): - base_model.lm_head.weight.requires_grad_(False) - torch._dynamo.config.optimize_ddp = False - compiled_model = torch.compile(base_model, mode=args.compile_mode if args.compile_mode != "default" else None) - use_find_unused = args.untie_at_fraction > 0 or args.mtp_heads_count > 0 or not args.tie_embeddings - model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False, - find_unused_parameters=use_find_unused, - static_graph=not use_find_unused, - gradient_as_bucket_view=True) if distributed else compiled_model - - # --- Optimizers --- - _excl = {"tok_emb.weight", "lm_head.weight", "lm_head_correction"} - all_other_params = [(n, p) for n, p in base_model.named_parameters() - if not any(eh in n for eh in _excl)] - matrix_params = [p for n, p in all_other_params - if p.ndim == 2 and not any(pat in n for pat in CTP)] - scalar_params = [p for n, p in all_other_params - if p.ndim < 2 or any(pat in n for pat in CTP)] - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - opt_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - if args.matrix_optimizer == "adamw": - opt_muon = torch.optim.AdamW( - [{"params": matrix_params, "lr": args.adam_lr, "base_lr": args.adam_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.adam_wd, fused=True) - else: - opt_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, wd=args.muon_wd) - for g in opt_muon.param_groups: - g["base_lr"] = args.matrix_lr - opt_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - opt_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": 0.0, "base_lr": 0.0}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - optimizers = [opt for opt in [opt_tok, opt_muon, opt_scalar, opt_head] if opt is not None] - if base_model.lm_head_correction is not None: - opt_corr = torch.optim.Adam( - [{"params": [base_model.lm_head_correction], - "lr": args.corr_weight_lr, "base_lr": args.corr_weight_lr}], - betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) - optimizers.append(opt_corr) - - # --- Log all hyperparameters --- - log0("--- Hyperparameters ---", console=False) - log0(" ".join(f"{a}={getattr(args,a)}" for a in sorted(dir(args)) if not a.startswith("_") and a not in ("train_files","val_files") and not callable(getattr(args,a))), console=False) - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"params:{n_params} L:{args.num_layers} d:{args.model_dim} h:{args.num_heads} kv:{args.num_kv_heads} ws:{world_size} ga:{grad_accum_steps} s:{args.seed}") - # --- Data loader & helpers --- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - def zero_grad_all(): - for opt in optimizers: - opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - def lr_mul(step: int, elapsed_ms: float): - if args.warmdown_fraction <= 0: - return 1.0 - if max_wallclock_ms is None: - warmdown_start = int(args.iterations * (1.0 - args.warmdown_fraction)) - return max((args.iterations - step) / max(args.iterations * args.warmdown_fraction, 1), 0.0) if step >= warmdown_start else 1.0 - warmdown_ms = max_wallclock_ms * args.warmdown_fraction - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - _seq_switched = False - _batch_switched = False - active_seq_len = args.seq_len_start if args.seq_len_start > 0 else args.train_seq_len - active_batch_tokens = args.batch_tokens_start if args.batch_tokens_start > 0 else args.train_batch_tokens - # --- Compiler warmup --- - if args.warmup_steps > 0: - _ms = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} - _os = [copy.deepcopy(o.state_dict()) for o in optimizers] - model.train() - for ws in range(args.warmup_steps): - zero_grad_all() - for mi in range(grad_accum_steps): - if distributed: model.require_backward_grad_sync = mi == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = model(x, y) - (loss * grad_scale).backward() - for o in optimizers: o.step() - zero_grad_all() - log0(f"warmup:{ws+1}/{args.warmup_steps}") - base_model.load_state_dict(_ms, strict=True) - for o, s in zip(optimizers, _os): o.load_state_dict(s) - zero_grad_all() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # --- EMA model --- - ema_model = None - _ema_started = False - _ema_steps = 0 - if args.ema: - ema_model = copy.deepcopy(base_model) - for p in ema_model.parameters(): - p.requires_grad_(False) - - # --- Main training loop --- - training_time_ms = 0.0 - stop_after_step: int | None = None - _untied = False - train_loss = torch.zeros((), device=device) - torch.cuda.synchronize() - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms") - torch.cuda.synchronize() - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") - break - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - # Sequence length schedule - if args.seq_len_start > 0 and not _seq_switched: - if max_wallclock_ms is not None: - should_switch_seq = elapsed_ms >= args.seq_schedule_fraction * max_wallclock_ms - else: - should_switch_seq = step >= int(args.iterations * args.seq_schedule_fraction) - if should_switch_seq: - active_seq_len = args.train_seq_len - _seq_switched = True - torch._dynamo.reset() - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - log0(f"step:{step} seq_len_switch:{args.seq_len_start}->{active_seq_len}") - - # Batch size schedule - if args.batch_tokens_start > 0 and not _batch_switched: - if max_wallclock_ms is not None: - should_switch_batch = elapsed_ms >= args.batch_schedule_fraction * max_wallclock_ms - else: - should_switch_batch = step >= int(args.iterations * args.batch_schedule_fraction) - if should_switch_batch: - active_batch_tokens = args.train_batch_tokens - _batch_switched = True - log0(f"step:{step} batch_switch:{args.batch_tokens_start}->{active_batch_tokens}") - zero_grad_all() - train_loss.zero_() - for micro in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro == grad_accum_steps - 1 - x, y = train_loader.next_batch(active_batch_tokens, active_seq_len, grad_accum_steps) - torch.compiler.cudagraph_mark_step_begin() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = model(x, y) - train_loss.add_(loss.detach()) - (loss * grad_scale).backward() - train_loss /= grad_accum_steps - - # Untie lm_head at configured fraction of training - if args.untie_at_fraction > 0: - if max_wallclock_ms is not None: - should_untie = not _untied and elapsed_ms >= args.untie_at_fraction * max_wallclock_ms - else: - should_untie = not _untied and step >= int(args.iterations * args.untie_at_fraction) - if should_untie and base_model.tie_embeddings: - with torch.no_grad(): - base_weight = base_model.tok_emb.weight.float() - if base_model.lm_head_correction is not None: - base_weight = base_weight + base_model.lm_head_correction.float() - if base_model.embed_proj_rev is not None: - full_weight = base_weight @ base_model.embed_proj_rev.weight.float() - else: - full_weight = base_weight - base_model.lm_head.weight.copy_(full_weight) - base_model.tie_embeddings = False - base_model.lm_head.weight.requires_grad_(True) - for g in opt_head.param_groups: - g["lr"] = g["base_lr"] = args.head_lr - _untied = True - torch._dynamo.reset() - log0(f"step:{step} untied lm_head (head_lr={args.head_lr})") - - # Muon momentum warmup - if args.matrix_optimizer != "adam": - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - for g in opt_muon.param_groups: - g["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - - # LR scheduling - for opt in optimizers: - for g in opt.param_groups: - g["lr"] = g["base_lr"] * scale - opt.step() - zero_grad_all() - # EMA update - if ema_model is not None: - if not _ema_started: - if max_wallclock_ms is not None: - should_start_ema = elapsed_ms >= args.ema_start_fraction * max_wallclock_ms - else: - should_start_ema = step >= int(args.iterations * args.ema_start_fraction) - if should_start_ema: - _ema_started = True - _ema_steps = 0 - with torch.no_grad(): - for ep, bp in zip(ema_model.parameters(), base_model.parameters()): - ep.data.copy_(bp.data) - log0(f"step:{step} ema_started") - if _ema_started: - _ema_steps += 1 - decay = min(args.ema_decay, (1.0 + _ema_steps) / (10.0 + _ema_steps)) - with torch.no_grad(): - for ep, bp in zip(ema_model.parameters(), base_model.parameters()): - ep.data.mul_(decay).add_(bp.data, alpha=1.0 - decay) - step += 1 - approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - if args.train_log_every > 0 and step % args.train_log_every == 0: - log0(f"step:{step}/{args.iterations} loss:{train_loss.item():.4f} t:{approx_ms:.0f}ms avg:{approx_ms/step:.1f}ms") - if args.churn_log_every > 0 and step % args.churn_log_every == 0: - log0(f"step:{step} churn:{churn_fn(base_model, args.bitnet_group_size):.4f}") - # Wallclock cap sync - if stop_after_step is None and max_wallclock_ms is not None and step % 10 == 0: - reached_cap = approx_ms >= max_wallclock_ms - if distributed: - cap_t = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(cap_t, op=dist.ReduceOp.MAX) - reached_cap = bool(cap_t.item()) - if reached_cap: - stop_after_step = step - - # --- Serialization --- - if master_process: - sd = (ema_model if ema_model is not None and _ema_started else base_model).state_dict() - if base_model.tie_embeddings or args.logit_head_type == "tversky": - sd.pop("lm_head.weight", None) - - # Compute binary overrides for no-features Tversky prototypes - binary_overrides = set() - for n, m in base_model.named_modules(): - if isinstance(m, TverskyProjection) and m.no_features_mode: - binary_overrides.add(n + ".prototypes") - binary_overrides = binary_overrides or None - q_obj, q_stats = q_sd(sd, group_size=args.bitnet_group_size, fp_storage=args.fp_storage, binary_override_names=binary_overrides) - buf = io.BytesIO() - torch.save(q_obj, buf) - final_blob = lzma.compress(buf.getvalue(), preset=9) - with open("final_model.binary.ptz", "wb") as f: - f.write(final_blob) - artifact_bytes = len(final_blob) - code_bytes = len(code.encode("utf-8")) - total = artifact_bytes + code_bytes - log0(f"artifact:{artifact_bytes/1e6:.2f}MB binary:{q_stats['binary_params']}({q_stats['binary_bytes']}B) fp:{q_stats['fp_params']}({q_stats['fp_bytes']}B) code:{code_bytes}") - log0(f"budget:{total}/{16000000} ({total/1e6:.2f}/{16.00:.2f}MB) {'FITS' if total <= 16000000 else 'OVER'}") - if args.eval_depth_recurrence > 0: - base_model.training_depth_recurrence = args.eval_depth_recurrence - log0(f"eval_depth_recurrence:{args.eval_depth_recurrence}") - - # --- All ranks load roundtrip weights and evaluate --- - if distributed: - dist.barrier() - with open("final_model.binary.ptz", "rb") as f: - loaded = torch.load(io.BytesIO(lzma.decompress(f.read())), map_location="cpu", weights_only=False) - base_model.load_state_dict(deq_sd(loaded), strict=False) - if ema_model is not None: - ema_model.load_state_dict(deq_sd(loaded), strict=False) - torch._dynamo.reset() - q_val_loss, q_val_bpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - log0(f"final_binary_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") - - opt_temp = 1.0 - if args.temp_scaling: - torch.cuda.synchronize() - t_temp = time.perf_counter() - calibration_tokens = train_loader.stream.take(65536).to(device) - opt_temp = find_temp(args, base_model, rank, world_size, device, grad_accum_steps, - calibration_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut) - torch.cuda.synchronize() - temp_time_ms = 1000.0 * (time.perf_counter() - t_temp) - log0(f"temp_scaling optimal_T:{opt_temp:.2f} eval_time:{temp_time_ms:.0f}ms") - - if args.sliding_eval: - torch.cuda.synchronize() - t_sliding = time.perf_counter() - sw_loss, sw_bpb = eval_val_sliding(args, base_model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, stride=args.sliding_eval_stride, - temperature=opt_temp) - torch.cuda.synchronize() - sliding_time_ms = 1000.0 * (time.perf_counter() - t_sliding) - log0(f"final_sliding val_loss:{sw_loss:.4f} val_bpb:{sw_bpb:.4f} " - f"(stride={args.sliding_eval_stride}, T={opt_temp:.2f}) eval_time:{sliding_time_ms:.0f}ms") - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 911b0e52f0..0000000000 --- a/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -numpy -tqdm -torch -huggingface-hub -kernels -setuptools -typing-extensions==4.15.0 -datasets -tiktoken -sentencepiece \ No newline at end of file diff --git a/submission.json b/submission.json new file mode 100644 index 0000000000..8c83cecbe1 --- /dev/null +++ b/submission.json @@ -0,0 +1,27 @@ +{ + "name": "pireylow", + "github_id": "pireylow", + "submission_type": "non-record (negative results)", + "val_bpb": 1.3696, + "val_bpb_note": "Best result: baseline + parallel residuals + TTT + QK 5.25, with sliding window eval and score-first TTT. Novel techniques (CAT, Sparsity, MoE, KAN) did not improve over baseline.", + "tokenizer": "sp8192", + "hardware": "1xH100 SXM (RunPod)", + "training_steps": 2000, + "training_mode": "medium (scaled-down from full 8xH100 config)", + "model_dim": 512, + "num_layers": 11, + "num_heads": 8, + "num_kv_heads": 4, + "total_params": "35.9M", + "artifact_size_bytes": 16076488, + "techniques_tested": [ + "CAT (Compressor-Aware Training)", + "2:4 Structured Sparsity", + "Hessian-Guided 2:4 Sparsity", + "Mixture of Experts (MoE)", + "KAN (Kolmogorov-Arnold Networks)", + "Parallel Residuals (GPT-J style)", + "Test-Time Training (Score-First SGD)", + "QK Gain tuning (4.0 vs 5.25)" + ] +} diff --git a/train_gpt.py b/train_gpt.py index 651beb2b89..04cce876de 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1,1123 +1,1610 @@ -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" - -from __future__ import annotations - -import copy -import glob -import io -import math -import os -import random -import subprocess -import sys -import time -import uuid -import zlib +import collections, copy, glob, io, lzma, math, os from pathlib import Path +import random, re, subprocess, sys, time, uuid import numpy as np import sentencepiece as spm import torch import torch.distributed as dist import torch.nn.functional as F -from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn -# ----------------------------- -# HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap - -class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. - data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - train_files = os.path.join(data_path, "fineweb_train_*.bin") - val_files = os.path.join(data_path, "fineweb_val_*.bin") - tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + flash_attn_3_func = None + _HAS_FA3 = False + + +def _sdpa_attn(q, k, v, causal=True): + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + if k.size(1) < q.size(1): + r = q.size(1) // k.size(1) + k = k.repeat_interleave(r, dim=1) + v = v.repeat_interleave(r, dim=1) + return F.scaled_dot_product_attention(q, k, v, is_causal=causal).transpose(1, 2) + + +# --------------------------------------------------------------------------- +# Test mode overrides +# --------------------------------------------------------------------------- +_TEST_MODE = os.environ.get("TEST_MODE", "full").lower() +_TEST_OVERRIDES = { + "smoke": { + "NUM_LAYERS": "5", "MODEL_DIM": "256", "NUM_HEADS": "4", "NUM_KV_HEADS": "2", + "EMBEDDING_DIM": "256", "ITERATIONS": "100", "WARMUP_STEPS": "5", + "WARMDOWN_FRAC": "0.667", "TRAIN_BATCH_TOKENS": str(512 * 16), + "VAL_BATCH_TOKENS": str(512 * 16), "TRAIN_SEQ_LEN": "512", "EVAL_SEQ_LEN": "512", + "VAL_LOSS_EVERY": "50", "TRAIN_LOG_EVERY": "10", "MAX_WALLCLOCK_SECONDS": "300", + "GPTQ_CALIBRATION_BATCHES": "4", "GPTQ_RESERVE_SECONDS": "5.0", "EMA_DECAY": "0.95", + "NUM_LOOPS": "0", "LOOP_START": "0", "LOOP_END": "0", "XSA_LAST_N": "5", + "SLIDING_WINDOW_ENABLED": "0", + }, + "small": { + "NUM_LAYERS": "7", "MODEL_DIM": "384", "NUM_HEADS": "6", "NUM_KV_HEADS": "3", + "EMBEDDING_DIM": "384", "ITERATIONS": "500", "WARMUP_STEPS": "10", + "WARMDOWN_FRAC": "0.667", "TRAIN_BATCH_TOKENS": str(1024 * 16), + "VAL_BATCH_TOKENS": str(1024 * 16), "TRAIN_SEQ_LEN": "1024", "EVAL_SEQ_LEN": "1024", + "VAL_LOSS_EVERY": "250", "TRAIN_LOG_EVERY": "50", "MAX_WALLCLOCK_SECONDS": "600", + "GPTQ_CALIBRATION_BATCHES": "8", "GPTQ_RESERVE_SECONDS": "8.0", "EMA_DECAY": "0.99", + "NUM_LOOPS": "0", "LOOP_START": "0", "LOOP_END": "0", "XSA_LAST_N": "7", + "SLIDING_WINDOW_ENABLED": "0", + }, + "medium": { + "NUM_LAYERS": "11", "MODEL_DIM": "512", "NUM_HEADS": "8", "NUM_KV_HEADS": "4", + "EMBEDDING_DIM": "512", "ITERATIONS": "2000", "WARMUP_STEPS": "15", + "WARMDOWN_FRAC": "0.667", "TRAIN_BATCH_TOKENS": str(2048 * 8), + "VAL_BATCH_TOKENS": str(2048 * 8), "TRAIN_SEQ_LEN": "2048", "EVAL_SEQ_LEN": "2048", + "VAL_LOSS_EVERY": "500", "TRAIN_LOG_EVERY": "100", "MAX_WALLCLOCK_SECONDS": "1200", + "GPTQ_CALIBRATION_BATCHES": "16", "GPTQ_RESERVE_SECONDS": "10.0", "EMA_DECAY": "0.99", + "NUM_LOOPS": "2", "LOOP_START": "3", "LOOP_END": "5", "ENABLE_LOOPING_AT": "0.35", + "XSA_LAST_N": "11", "SLIDING_WINDOW_ENABLED": "1", + }, + "full": {}, +} +if _TEST_MODE in _TEST_OVERRIDES: + for _k, _v in _TEST_OVERRIDES[_TEST_MODE].items(): + if _k not in os.environ: + os.environ[_k] = _v + +_USE_TORCH_COMPILE = _TEST_MODE not in ("smoke",) + +# --------------------------------------------------------------------------- +# Novel feature toggles +# --------------------------------------------------------------------------- +_CAT_ENABLED = bool(int(os.environ.get("CAT_ENABLED", "0"))) +_CAT_WEIGHT = float(os.environ.get("CAT_WEIGHT", "0.001")) +_CAT_BITS = int(os.environ.get("CAT_BITS", "6")) +_CAT_EVERY = int(os.environ.get("CAT_EVERY", "50")) + +_SPARSITY_ENABLED = bool(int(os.environ.get("SPARSITY_ENABLED", "0"))) +_SPARSITY_APPLY_TO = os.environ.get("SPARSITY_APPLY_TO", "mlp").split(",") + + +# --------------------------------------------------------------------------- +# Hyperparameters (defaults match top-1 submission where applicable) +# --------------------------------------------------------------------------- +class Hyperparameters(): + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) - - # Validation cadence and batch size. Validation always uses the full fineweb_val split. - val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - - # Training length. - iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) - warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) - max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Model shape. - vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) - num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) - num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) - logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - - # Optimizer hyperparameters. - embed_lr = float(os.environ.get("EMBED_LR", 0.6)) - head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) - tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) - adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - -# ----------------------------- -# MUON OPTIMIZER -# ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ - -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - curr = 0 - for p in params: - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - - return loss + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.72)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + vocab_size = int(os.environ.get('VOCAB_SIZE', 8192)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + + # QK gain bumped to 5.25 (top-1 uses 5.25, top-2/3 use 5.0) + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 5.25)) + + # Recurrence: loop layers 3-5 (top-1 uses 3-5, baseline used 4-5) + num_loops = int(os.environ.get('NUM_LOOPS', 2)) + loop_start = int(os.environ.get('LOOP_START', 3)) + loop_end = int(os.environ.get('LOOP_END', 5)) + enable_looping_at = float(os.environ.get('ENABLE_LOOPING_AT', 0.35)) + + # Parallel residuals from layer 7+ (GPT-J style, top-1/2/4 use this) + parallel_residual_start = int(os.environ.get('PARALLEL_RESIDUAL_START', 7)) + + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + + # Matrix LR and weight decay bumped to match top-1 + matrix_lr = float(os.environ.get('MATRIX_LR', 0.022)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1'))) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.095)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.9965)) + compressor = os.environ.get('COMPRESSOR', 'brotli') + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0)) + matrix_bits = int(os.environ.get('MATRIX_BITS', 6)) + embed_bits = int(os.environ.get('EMBED_BITS', 8)) + matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85)) + embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0)) + + # TTT (Test-Time Training) - score-first SGD from top submissions + ttt_enabled = bool(int(os.environ.get('TTT_ENABLED', '1'))) + ttt_lr = float(os.environ.get('TTT_LR', 0.005)) + ttt_momentum = float(os.environ.get('TTT_MOMENTUM', 0.9)) + ttt_epochs = int(os.environ.get('TTT_EPOCHS', 3)) + ttt_chunk_tokens = int(os.environ.get('TTT_CHUNK_TOKENS', 32768)) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + logfile = f"logs/{run_id}.txt" + model_path = "final_model.pt" + quantized_model_path = "final_model.int6.ptz" -# ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. -# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device -) -> tuple[Tensor, Tensor, Tensor]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_np = np.zeros((table_size,), dtype=np.int16) - has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_np[token_id] = False - if sp.is_byte(token_id): - base_bytes_np[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_np[token_id] = True - piece = piece[1:] - base_bytes_np[token_id] = len(piece.encode("utf-8")) - return ( - torch.tensor(base_bytes_np, dtype=torch.int16, device=device), - torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), - torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), - ) - - -def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() - usable = ((tokens.numel() - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def eval_val( - args: Hyperparameters, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - if local_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " - f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" - ) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- +_logger_hparams = None - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h -# ----------------------------- -# POST-TRAINING QUANTIZATION -# ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - -def tensor_nbytes(t: Tensor) -> int: - return int(t.numel()) * int(t.element_size()) - -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) +def log(msg, console=False): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) - for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) - continue - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) - continue +# --------------------------------------------------------------------------- +# Data loading +# --------------------------------------------------------------------------- +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + assert int(self.sp.vocab_size()) == h.vocab_size + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = \ + build_sentencepiece_luts(self.sp, h.vocab_size, device) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sv = int(sp.vocab_size()) + assert sp.piece_to_id("\u2581") != sp.unk_id() + base_bytes = torch.zeros(max(sv, vocab_size), dtype=torch.int32, device=device) + has_leading_space = torch.zeros(max(sv, vocab_size), dtype=torch.bool, device=device) + is_boundary = torch.zeros(max(sv, vocab_size), dtype=torch.bool, device=device) + for t in range(sv): + piece = sp.id_to_piece(t) + raw = piece.replace("\u2581", " ").encode("utf-8") + base_bytes[t] = len(raw) + if piece.startswith("\u2581"): + has_leading_space[t] = True + is_boundary[t] = True + elif sp.is_control(t) or sp.is_unknown(t): + is_boundary[t] = True + return base_bytes, has_leading_space, is_boundary + + +def load_validation_tokens(pattern, seq_len): + files = sorted(glob.glob(pattern)) + assert files, f"No files: {pattern}" + chunks = [] + for f in files: + chunks.append(load_data_shard(Path(f))) + tokens = torch.cat(chunks) + total = tokens.numel() + usable = (total - 1) // seq_len * seq_len + 1 + return tokens[:usable] + + +def load_data_shard(file): + hb = 256 * np.dtype(" 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() - else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t - return out +_SHARD_HEADER_BYTES = 256 * np.dtype(" Tensor: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> Tensor: - chunks: list[Tensor] = [] - remaining = n - while remaining > 0: - avail = self.tokens.numel() - self.pos - if avail <= 0: - self._advance_file() - continue - k = min(remaining, avail) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - remaining -= k - return chunks[0] if len(chunks) == 1 else torch.cat(chunks) - - -class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. - def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): - self.rank = rank - self.world_size = world_size + assert header.size == 256 and int(header[0]) == 20240520 and int(header[1]) == 1 + n = int(header[2]) + _SHARD_NTOKENS_CACHE[key] = n + return n + + +def _get_shard_memmap(file): + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype=" tuple[Tensor, Tensor]: - local_tokens = global_tokens // (self.world_size * grad_accum_steps) - per_rank_span = local_tokens + 1 - chunk = self.stream.take(per_rank_span * self.world_size) - start = self.rank * per_rank_span - local = chunk[start : start + per_rank_span].to(dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + assert all_files, f"No files: {h.train_files}" + self.files = all_files[h.rank::h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + mp = min(self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1)) + phase = int(self.rng.integers(mp + 1)) if mp > 0 else 0 + ns = (self.num_tokens[si] - 1 - phase) // self.seq_len + self.start_inds[si] = (phase + self.rng.permutation(ns) * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + dt = global_tokens // (self.world_size * grad_accum_steps) + dbs = dt // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((dbs, self.seq_len), dtype=torch.int64) + y = torch.empty((dbs, self.seq_len), dtype=torch.int64) + for bi in range(dbs): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + total = remaining.sum() + si = int(self.rng.choice(len(self.files), p=remaining / total)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + w = torch.as_tensor(np.array(mm[start_ind:start_ind + self.seq_len + 1], dtype=np.int64)) + x[bi] = w[:-1] + y[bi] = w[1:] return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) -# ----------------------------- -# TRANSFORMER MODULES -# ----------------------------- +# --------------------------------------------------------------------------- +# Model architecture +# --------------------------------------------------------------------------- class RMSNorm(nn.Module): - def __init__(self, eps: float | None = None): + def __init__(self, eps=None): super().__init__() self.eps = eps - def forward(self, x: Tensor) -> Tensor: + def forward(self, x): return F.rms_norm(x, (x.size(-1),), eps=self.eps) class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. - def forward(self, x: Tensor) -> Tensor: - bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) - - -def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. - with torch.no_grad(): - for name, param in module.named_parameters(): - if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() + def forward(self, x): + w = self.weight.to(x.dtype) + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, b) class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. - def __init__(self, dim: int, base: float = 10000.0): + def __init__(self, dim, base=10000.0, train_seq_len=1024, rope_dims=0): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.register_buffer( + "inv_freq", + 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)), + persistent=False, + ) self._seq_len_cached = 0 - self._cos_cached: Tensor | None = None - self._sin_cached: Tensor | None = None - - def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: - if ( - self._cos_cached is None - or self._sin_cached is None - or self._seq_len_cached != seq_len - or self._cos_cached.device != device - ): - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.outer(t, self.inv_freq.to(device)) - self._cos_cached = freqs.cos()[None, None, :, :] - self._sin_cached = freqs.sin()[None, None, :, :] + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if (self._cos_cached is None or self._sin_cached is None + or self._seq_len_cached != seq_len or self._cos_cached.device != device): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] self._seq_len_cached = seq_len return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + xr, xp = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = xr[..., :half], xr[..., half:] + return torch.cat(( + torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1), + xp, + ), dim=-1) half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len): super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim + kv_dim = num_kv_heads * self.head_dim self.c_q = CastedLinear(dim, dim, bias=False) self.c_k = CastedLinear(dim, kv_dim, bias=False) self.c_v = CastedLinear(dim, kv_dim, bias=False) self.proj = CastedLinear(dim, dim, bias=False) self.proj._zero_init = True self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) - self.rotary = Rotary(self.head_dim, base=rope_base) - - def forward(self, x: Tensor) -> Tensor: + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + g = H // Hkv + y_g = y.reshape(B, T, Hkv, g, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + return (y_g - (y_g * vn).sum(dim=-1, keepdim=True) * vn).reshape(B, T, H, D) + + def forward(self, x): bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) - q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, - enable_gqa=(self.num_kv_heads != self.num_heads), - ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) - return self.proj(y) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) if _HAS_FA3 else _sdpa_attn(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + return self.proj(y.reshape(bsz, seqlen, dim)) class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): + def __init__(self, dim, mlp_mult): super().__init__() - hidden = mlp_mult * dim + hidden = int(mlp_mult * dim) self.fc = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) self.proj._zero_init = True - def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) - return self.proj(x.square()) + def forward(self, x): + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + train_seq_len, layer_idx=0, ln_scale=False, parallel=False): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) self.mlp = MLP(dim, mlp_mult) self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.parallel = parallel - def forward(self, x: Tensor, x0: Tensor) -> Tensor: + def forward(self, x, x0): mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor) + + if self.parallel: + # GPT-J style: attention and MLP both read from same pre-attention input + mlp_out = self.mlp(self.mlp_norm(x_in) * self.ln_scale_factor) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out \ + + self.mlp_scale.to(dtype=x_in.dtype)[None, None, :] * mlp_out + else: + # Standard sequential: MLP reads post-attention output + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * \ + self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + + return x_out class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - model_dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - tie_embeddings: bool, - tied_embed_init_std: float, - logit_softcap: float, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, h): super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.tie_embeddings = tie_embeddings - self.tied_embed_init_std = tied_embed_init_std - self.logit_softcap = logit_softcap - self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) - self.blocks = nn.ModuleList( - [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) - ] - ) + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale, + parallel=(h.parallel_residual_start >= 0 and i >= h.parallel_residual_start)) + for i in range(h.num_layers) + ]) + + if h.rope_dims > 0: + hd = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(hd, base=h.rope_base, train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims) + self.final_norm = RMSNorm() - self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) if self.lm_head is not None: self.lm_head._zero_init = True + + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + + # Layer looping / recurrence + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + ne = len(all_indices) // 2 + self.encoder_indices = all_indices[:ne] + self.decoder_indices = all_indices[ne:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + + # Skip connections (U-Net style) + self.num_skip_weights = min(len(self.encoder_indices), len(self.decoder_indices)) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) if h.skip_gates_enabled else None + self._init_weights() - def _init_weights(self) -> None: + def _init_weights(self): if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif (module.weight.ndim == 2 and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64): + nn.init.orthogonal_(module.weight, gain=1.0) + + def forward_logits(self, input_ids): x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) x0 = x - skips: list[Tensor] = [] - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): + skips = [] + enc_iter = self.encoder_indices if self.looping_active else range(self.num_encoder_layers) + dec_iter = self.decoder_indices if self.looping_active else range( + self.num_encoder_layers, self.num_encoder_layers + self.num_decoder_layers + ) + + for i in enc_iter: x = self.blocks[i](x, x0) skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) - if self.tie_embeddings: - logits_proj = F.linear(x, self.tok_emb.weight) - else: - if self.lm_head is None: - raise RuntimeError("lm_head is required when tie_embeddings=False") - logits_proj = self.lm_head(x) - logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) - return F.cross_entropy(logits.float(), targets, reduction="mean") + for skip_idx, i in enumerate(dec_iter): + if skip_idx < self.num_skip_weights and skips: + ss = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(ss, x, g) + else: + x = x + ss + x = self.blocks[i](x, x0) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + lp = F.linear(x, self.tok_emb.weight.to(x.dtype)) if self.tie_embeddings else self.lm_head(x) + return self.logit_softcap * torch.tanh(lp / self.logit_softcap) -# ----------------------------- -# TRAINING -# ----------------------------- + def forward(self, input_ids, target_ids): + logits = self.forward_logits(input_ids) + loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean" + ) + return loss -def main() -> None: - global zeropower_via_newtonschulz5 - code = Path(__file__).read_text(encoding="utf-8") - args = Hyperparameters() - zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) +# --------------------------------------------------------------------------- +# CAT (Compressor-Aware Training) +# --------------------------------------------------------------------------- +def cat_compression_loss(model, bits=6, temperature=0.1): + dev = next(model.parameters()).device + total = torch.tensor(0.0, device=dev) + count = 0 + cr = 2 ** (bits - 1) - 1 + for name, p in model.named_parameters(): + if p.ndim < 2 or p.numel() <= 65536: + continue + w = p.float() + row_std = w.std(dim=-1, keepdim=True).clamp_min(1e-10) + scaled = w / row_std + q = torch.round(scaled * (cr / 3.0)) + q = q.clamp(-cr, cr) + dist = (scaled * (cr / 3.0) - q).abs() + soft_round = torch.sigmoid((dist - 0.5) / temperature) + total = total + soft_round.mean() + count += 1 + return total / max(count, 1) + + +# --------------------------------------------------------------------------- +# 2:4 Sparsity (Hessian-guided: uses GPTQ Hessians when available) +# --------------------------------------------------------------------------- +def apply_sparsity_to_state_dict(sd, apply_to, hessians=None): + result = {} + sparsified = 0 + hessian_guided = 0 + for name, tensor in sd.items(): + if tensor.ndim != 2 or tensor.numel() <= 65536: + result[name] = tensor + continue + cat = classify_param(name) + if cat not in apply_to: + result[name] = tensor + continue + t = tensor.float() + rows, cols = t.shape + + # Hessian-guided importance: |w_j| * sqrt(H_jj) instead of just |w_j| + # This preserves weights that are small but important to the loss + if hessians is not None and name in hessians: + H = hessians[name].float() + diag_H = torch.diag(H).clamp_min(1e-8) + col_importance = torch.sqrt(diag_H) # shape: (cols,) + importance = t.abs() * col_importance.unsqueeze(0) # (rows, cols) + hessian_guided += 1 + else: + importance = t.abs() + + pad = (4 - cols % 4) % 4 + if pad: + t = F.pad(t, (0, pad)) + importance = F.pad(importance, (0, pad)) + t4 = t.reshape(rows, -1, 4) + imp4 = importance.reshape(rows, -1, 4) + _, idx = imp4.topk(2, dim=-1) + mask = torch.zeros_like(t4, dtype=torch.bool) + mask.scatter_(-1, idx, True) + t4 = t4 * mask + t_sparse = t4.reshape(rows, -1) + if pad: + t_sparse = t_sparse[:, :cols] + result[name] = t_sparse.to(tensor.dtype) + sparsified += 1 + log(f"sparsity: applied 2:4 to {sparsified} tensors ({hessian_guided} Hessian-guided)") + return result + + +# --------------------------------------------------------------------------- +# Param classification and optimizer +# --------------------------------------------------------------------------- +def classify_param(name): + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + + +def _zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - rank = int(os.environ.get("RANK", "0")) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size - grad_scale = 1.0 / grad_accum_steps - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - master_process = rank == 0 +if _USE_TORCH_COMPILE: + zeropower_via_newtonschulz5 = torch.compile(_zeropower_via_newtonschulz5) +else: + zeropower_via_newtonschulz5 = _zeropower_via_newtonschulz5 - # Fast math knobs - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr, momentum, backend_steps, nesterov=True, + weight_decay=0.0, row_normalize=False): + super().__init__(params, dict( + lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay, row_normalize=row_normalize, + )) - logfile = None - if master_process: - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{args.run_id}.txt" - print(logfile) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if group.get("row_normalize", False): + rn = g.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + g = g / rn.to(g.dtype) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss - def log0(msg: str, console: bool = True) -> None: - if not master_process: - return - if console: - print(msg) - if logfile is not None: - with open(logfile, "a", encoding="utf-8") as f: - print(msg, file=f) - log0(code, console=False) - log0("=" * 100, console=False) - log0(f"Running Python {sys.version}", console=False) - log0(f"Running PyTorch {torch.__version__}", console=False) - log0( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log0("=" * 100, console=False) - - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" +CONTROL_TENSOR_NAME_PATTERNS = tuple( + p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes," + "q_gain,skip_weight,skip_weights,skip_gates" + ).split(",") if p +) + + +class Optimizers(): + def __init__(self, h, base_model): + bnp = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in bnp + if p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in bnp + if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + self.optimizer_tok = torch.optim.AdamW( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.embed_wd, fused=True, ) - dataset_dir = Path(args.data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size, device - ) - log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") - log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - - base_model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - model_dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - tie_embeddings=args.tie_embeddings, - tied_embed_init_std=args.tied_embed_init_std, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - qk_gain_init=args.qk_gain_init, - ).to(device).bfloat16() - for module in base_model.modules(): - if isinstance(module, CastedLinear): - module.float() - restore_low_dim_params_to_fp32(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizer_muon = Muon( - matrix_params, - lr=args.matrix_lr, - momentum=args.muon_momentum, - backend_steps=args.muon_backend_steps, - ) - for group in optimizer_muon.param_groups: - group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( - [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, - ) - optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] - if base_model.lm_head is not None: - optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], - betas=(args.beta1, args.beta2), - eps=args.adam_eps, - fused=True, + self.optimizer_muon = Muon( + matrix_params, lr=h.matrix_lr, momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, ) - optimizers.insert(1, optimizer_head) - - n_params = sum(p.numel() for p in base_model.parameters()) - log0(f"model_params:{n_params}") - log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") - log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") - log0( - f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" - ) - log0( - f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " - f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " - f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log0(f"seed:{args.seed}") - - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - def zero_grad_all() -> None: - for opt in optimizers: + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.adam_wd, fused=True, + ) + self.optimizers = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), eps=h.adam_eps, fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: opt.zero_grad(set_to_none=True) - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and param.dtype != torch.float32: + param.data = param.data.float() + + +# --------------------------------------------------------------------------- +# GPTQ quantization +# --------------------------------------------------------------------------- +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + + def make_hook(name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(x.T, x) + return hook_fn + + for name, module in model.named_modules(): + if isinstance(module, CastedLinear) and module.weight.numel() > 65536: + cat = classify_param(name + ".weight") + if cat in ("mlp", "attn"): + hooks.append(module.register_forward_hook(make_hook(name + ".weight"))) + + if model.tie_embeddings: + hm = model.head_proj if model.head_proj is not None else model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device) + hessians[name].addmm_(x.T, x) + return hook_fn + + hooks.append(hm.register_forward_hook(make_output_hook("tok_emb.weight"))) + + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + + cholesky_ok = False + for extra_damp in [0.0, 0.1, 1.0, 10.0]: + try: + Ht = H.clone() + if extra_damp > 0: + Ht.diagonal().add_(extra_damp * Ht.diag().mean()) + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(Ht)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + cholesky_ok = True + break + except torch.linalg.LinAlgError: + continue - def lr_mul(step: int, elapsed_ms: float) -> float: - if args.warmdown_iters <= 0: + if not cholesky_ok: + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + return torch.clamp(torch.round(W_orig / s.float().unsqueeze(1)), -clip_range, clip_range).to(torch.int8), s + + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + + return Q[:, invperm], s + + +def _simple_quantize_weight(w, clip_sigmas=3.0, clip_range=63): + orig_shape = w.shape + W = w.float().reshape(w.shape[0], -1) + row_std = W.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + Q = torch.clamp(torch.round(W / s.float().unsqueeze(1)), -clip_range, clip_range).to(torch.int8) + return Q.reshape(orig_shape), s + + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + cs = h.embed_clip_sigmas if "tok_emb" in name else h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + if name in hessians: + q, s = gptq_quantize_weight(t, hessians[name], clip_sigmas=cs, clip_range=clip_range) + meta[name] = f"gptq (int{bits})" + else: + log(f" no Hessian for {name}, using simple quantization") + q, s = _simple_quantize_weight(t, clip_sigmas=cs, clip_range=clip_range) + meta[name] = f"simple (int{bits})" + result[name + ".q"] = q + result[name + ".scale"] = s + + cats = collections.defaultdict(set) + for name, cat in meta.items(): + short = re.sub(r'\.\d+$', '', re.sub(r'blocks\.\d+', 'blocks', name)) + cats[cat].add(short) + log("Quantized weights:") + for cat in sorted(cats): + log(f" {cat}: {', '.join(sorted(cats[cat]))}") + return result, meta + + +def dequantize_mixed(result, meta, template_sd): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + od = orig.dtype + if "passthrough" in info: + t = result[name] + out[name] = t.to(od) if t.dtype == torch.float16 and od in (torch.float32, torch.bfloat16) else t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(od) + else: + out[name] = (q.float() * float(s.item())).to(od) + return out + + +# --------------------------------------------------------------------------- +# Compression (byte-shuffle + brotli) +# --------------------------------------------------------------------------- +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[off:off + len(chunk)] = chunk + off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + off = 0 + for pos in range(stride): + cl = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:cl] = payload[off:off + cl] + off += cl + return out.tobytes() + + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data, compressor): + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + return _byte_unshuffle(raw) + + +# --------------------------------------------------------------------------- +# Serialization +# --------------------------------------------------------------------------- +def serialize(h, base_model, code): + code_bytes = len(code.encode("utf-8")) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + log(f"Serialized model: {os.path.getsize(h.model_path)} bytes") + log(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + device = torch.device("cuda", h.local_rank) + log("GPTQ:collecting Hessians...") + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + hessians = collect_hessians(base_model, calib_loader, h, device, + n_calibration_batches=h.gptq_calibration_batches) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s") + if _SPARSITY_ENABLED: + log("Applying Hessian-guided 2:4 sparsity...") + sd_cpu = apply_sparsity_to_state_dict(sd_cpu, _SPARSITY_APPLY_TO, hessians=hessians) + + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_blob = _compress(quant_buf.getvalue(), h.compressor) + qfb = len(quant_blob) + bt = qfb + code_bytes + + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized quantized+{h.compressor}: {qfb} bytes", console=True) + log(f"Total submission: {bt} bytes", console=True) + return bt, qfb + + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + with open(h.quantized_model_path, "rb") as f: + qbd = f.read() + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") + deq = dequantize_mixed(qs["w"], qs["m"], sd_cpu) + eval_model.load_state_dict(deq, strict=True) + return eval_model + + +# --------------------------------------------------------------------------- +# Evaluation +# --------------------------------------------------------------------------- +def _loss_bpb(loss_sum, token_count, byte_count): + vl = (loss_sum / token_count).item() + return vl, vl / math.log(2.0) * (token_count.item() / byte_count.item()) + + +def eval_val(h, device, val_data, model): + seq_len = h.eval_seq_len + lbt = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + assert lbt >= seq_len, "VAL_BATCH_TOKENS too small" + lbs = lbt // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + ss = (total_seqs * h.rank) // h.world_size + se = (total_seqs * (h.rank + 1)) // h.world_size + vls = torch.zeros((), device=device, dtype=torch.float64) + vtc = torch.zeros((), device=device, dtype=torch.float64) + vbc = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for bss in range(ss, se, lbs): + bse = min(bss + lbs, se) + rs = bss * seq_len + re_ = bse * seq_len + 1 + local = val_data.val_tokens[rs:re_].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + bl = model(x, y).detach() + btc = float(y.numel()) + vls += bl.to(torch.float64) * btc + vtc += btc + prev = x.reshape(-1) + tgt = y.reshape(-1) + tb = val_data.base_bytes_lut[tgt].to(dtype=torch.int16) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(dtype=torch.int16) + vbc += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(vls, op=dist.ReduceOp.SUM) + dist.all_reduce(vtc, op=dist.ReduceOp.SUM) + dist.all_reduce(vbc, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(vls, vtc, vbc) + + +def eval_val_sliding(h, device, val_data, base_model, batch_seqs=32): + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) \ + if _USE_TORCH_COMPILE else base_model.forward_logits + seq_len = h.eval_seq_len + ctx = seq_len - h.eval_stride + tt = val_data.val_tokens.numel() - 1 + ws_list = [ws for ws in range(0, tt, h.eval_stride) if ws + ctx < tt] + tw = len(ws_list) + ms = (tw * h.rank) // h.world_size + me = (tw * (h.rank + 1)) // h.world_size + mw = ws_list[ms:me] + ls = torch.zeros((), device=device, dtype=torch.float64) + tc = torch.zeros((), device=device, dtype=torch.float64) + bc = torch.zeros((), device=device, dtype=torch.float64) + with torch.inference_mode(): + for bi in range(0, len(mw), batch_seqs): + bws = mw[bi:bi + batch_seqs] + bsz = len(bws) + xb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + yb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(bws): + we = min(ws + seq_len, tt) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + xb[i, :wlen] = chunk[:-1] + yb[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(xb) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), yb.reshape(-1), reduction="none" + ).reshape(bsz, seq_len) + for i, ws in enumerate(bws): + wlen = wlens[i] + s = 0 if ws == 0 else ctx + ls += nll[i, s:wlen].to(torch.float64).sum() + tc += float(wlen - s) + tgt = yb[i, s:wlen] + prev = xb[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + bc += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, op=dist.ReduceOp.SUM) + dist.all_reduce(tc, op=dist.ReduceOp.SUM) + dist.all_reduce(bc, op=dist.ReduceOp.SUM) + base_model.train() + return _loss_bpb(ls, tc, bc) + + +# --------------------------------------------------------------------------- +# TTT (Test-Time Training) - score-first chunk-based SGD from top submissions +# --------------------------------------------------------------------------- +def eval_val_ttt(h, device, val_data, base_model, batch_seqs=32): + """Score-first TTT: score each chunk under no_grad, then train on non-final chunks via SGD.""" + rank = h.rank + world_size = h.world_size + seq_len = h.eval_seq_len + stride = h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + ttt_chunk = h.ttt_chunk_tokens + context_size = seq_len - stride + + # Build sliding windows and assign to chunks + window_starts = [ws for ws in range(0, total_tokens, stride) if ws + context_size < total_tokens] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows = [[] for _ in range(num_chunks)] + for ws in window_starts: + wlen = min(ws + seq_len, total_tokens) - ws + s = 0 if ws == 0 else context_size + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log(f"ttt:start chunks={num_chunks} ttt_lr={h.ttt_lr} ttt_epochs={h.ttt_epochs}") + + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) \ + if _USE_TORCH_COMPILE else base_model.forward_logits + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD(ttt_params, lr=h.ttt_lr, momentum=h.ttt_momentum) + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = len(windows) * rank // world_size + my_e = len(windows) * (rank + 1) // world_size + my_windows = windows[my_s:my_e] + + # Phase 1: Score windows under no_grad (score-first approach) + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk_tok = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction='none' + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] + & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Phase 2: Train on chunk tokens (skip last chunk) + is_last_chunk = ci == num_chunks - 1 + if not is_last_chunk and h.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = h.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = chunk_seqs * rank // world_size + my_seq_e = chunk_seqs * (rank + 1) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(h.ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_data.val_tokens.numel(): + continue + local = val_data.val_tokens[start_tok:end_tok].to( + device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + return _loss_bpb(loss_sum, token_count, byte_count) + + +# --------------------------------------------------------------------------- +# Training +# --------------------------------------------------------------------------- +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + vl, vb = fn(*args, **kwargs) + torch.cuda.synchronize() + log(f"{label} val_loss:{vl:.8f} val_bpb:{vb:.8f} eval_time:{1000.0 * (time.perf_counter() - t0):.0f}ms", + console=True) + return vl, vb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) \ + if _USE_TORCH_COMPILE else base_model + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) \ + if h.distributed else compiled_model + + af = [] + if _CAT_ENABLED: + af.append(f"CAT(weight={_CAT_WEIGHT},bits={_CAT_BITS},every={_CAT_EVERY})") + if _SPARSITY_ENABLED: + af.append(f"2:4_Sparsity(apply_to={_SPARSITY_APPLY_TO})") + if h.parallel_residual_start >= 0: + af.append(f"ParallelResid(from_layer={h.parallel_residual_start})") + if h.ttt_enabled: + af.append(f"TTT(lr={h.ttt_lr},epochs={h.ttt_epochs},chunk={h.ttt_chunk_tokens})") + log(f"novel_features: {', '.join(af)}" if af else "novel_features: none (baseline)", console=True) + log(f"test_mode: {_TEST_MODE}", console=True) + log(f"fa3_available: {_HAS_FA3}") + log(f"torch_compile: {_USE_TORCH_COMPILE}") + log(f"model_params:{sum(p.numel() for p in base_model.parameters())}", console=True) + + optimizers = Optimizers(h, base_model) + train_loader = ShuffledSequenceLoader(h, device) + max_wc_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wc_ms is not None: + max_wc_ms -= h.gptq_reserve_seconds * 1000.0 + log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wc_ms:.0f}ms") + + def training_frac(step, elapsed_ms): + if max_wc_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wc_ms, 1e-9) + + def lr_mul(frac): + if h.warmdown_frac <= 0: return 1.0 - if max_wallclock_ms is None: - warmdown_start = max(args.iterations - args.warmdown_iters, 0) - return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = args.warmdown_iters * step_ms - remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. - if args.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + if _CAT_ENABLED and step % _CAT_EVERY == 0: + loss = loss + _CAT_WEIGHT * cat_compression_loss(base_model, bits=_CAT_BITS) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + mm = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = mm + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step() + return train_loss + + # Warmup phase + if h.warmup_steps > 0: + ims = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + ios = [copy.deepcopy(opt.state_dict()) for opt in optimizers] model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - warmup_loss = model(x, y) - (warmup_loss * grad_scale).backward() - for opt in optimizers: - opt.step() - zero_grad_all() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + for ws in range(h.warmup_steps): + step_fn(ws, 1.0) + if ws <= 5 or (ws + 1) % 10 == 0 or ws + 1 == h.warmup_steps: + log(f"warmup_step: {ws + 1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log(f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + for ws in range(h.warmup_steps): + step_fn(ws, 1.0) + if ws <= 5 or (ws + 1) % 10 == 0 or ws + 1 == h.warmup_steps: + log(f"loop_warmup_step: {ws + 1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(ims, strict=True) + for opt, state in zip(optimizers, ios, strict=True): opt.load_state_dict(state) - zero_grad_all() - if distributed: + optimizers.zero_grad_all() + if h.distributed: model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- + train_loader = ShuffledSequenceLoader(h, device) + # Main training loop + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = h.ema_decay training_time_ms = 0.0 - stop_after_step: int | None = None + stop_after_step = None torch.cuda.synchronize() t0 = time.perf_counter() - step = 0 + while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) - should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) if should_validate: torch.cuda.synchronize() training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - log0( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" - ) + vl, vb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {vl:.4f} val_bpb: {vb:.4f}", console=True) torch.cuda.synchronize() t0 = time.perf_counter() if last_step: - if stop_after_step is not None and step < args.iterations: - log0( - f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " - f"step:{step}/{args.iterations}" - ) + if stop_after_step is not None and step < h.iterations: + log(f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}", + console=True) break elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - scale = lr_mul(step, elapsed_ms) - zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(grad_accum_steps): - if distributed: - model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 - x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss * grad_scale).backward() - train_loss /= grad_accum_steps + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) - frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum - for group in optimizer_muon.param_groups: - group["momentum"] = muon_momentum + if h.num_loops > 0 and not base_model.looping_active and frac >= h.enable_looping_at: + base_model.looping_active = True + log(f"layer_loop:enabled step:{step} frac:{frac:.3f} " + f"encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * scale - - if args.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) - for opt in optimizers: - opt.step() - zero_grad_all() + train_loss = step_fn(step, scale) + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - should_log_train = ( - args.train_log_every > 0 - and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - log0( - f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " - f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" - ) - - # Needed to sync whether we've reached the wallclock cap. - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if h.train_log_every > 0 and (step <= 5 or step % h.train_log_every == 0 + or stop_after_step is not None): + tps = step * h.train_batch_tokens / (approx_ms / 1000.0) + log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} " + f"train_time: {approx_ms / 60000:.1f}m tok/s: {tps:.0f}") + + reached_cap = max_wc_ms is not None and approx_ms >= max_wc_ms + if h.distributed and max_wc_ms is not None: + rct = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(rct, op=dist.ReduceOp.MAX) + reached_cap = bool(rct.item()) if stop_after_step is None and reached_cap: stop_after_step = step - log0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. - - if master_process: - torch.save(base_model.state_dict(), "final_model.pt") - model_bytes = os.path.getsize("final_model.pt") - code_bytes = len(code.encode("utf-8")) - log0(f"Serialized model: {model_bytes} bytes") - log0(f"Code size: {code_bytes} bytes") - log0(f"Total submission size: {model_bytes + code_bytes} bytes") - - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) - quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) - if master_process: - with open("final_model.int8.ptz", "wb") as f: - f.write(quant_blob) - quant_file_bytes = os.path.getsize("final_model.int8.ptz") - code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + log(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB") + log("ema:applying EMA weights") + cs = base_model.state_dict() + avg = {name: t.to(dtype=cs[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg, strict=True) + return base_model, compiled_model + + +# --------------------------------------------------------------------------- +# Train + Eval pipeline +# --------------------------------------------------------------------------- +def train_and_eval(h, device): + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + if _USE_TORCH_COMPILE: + torch._dynamo.reset() + timed_eval("pre-quantization post-ema", eval_val, h, device, val_data, compiled_model) + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) \ + if _USE_TORCH_COMPILE else eval_model + + timed_eval("quantized", eval_val, h, device, val_data, compiled_model) + + if h.sliding_window_enabled: + timed_eval("quantized_sliding_window", eval_val_sliding, h, device, val_data, eval_model) + # TTT evaluation (score-first SGD adaptation at eval time) + if h.ttt_enabled and h.sliding_window_enabled: + timed_eval("quantized_ttt", eval_val_ttt, h, device, val_data, eval_model) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + assert torch.cuda.is_available(), "CUDA required" + assert world_size > 0 and 8 % world_size == 0 + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) if distributed: + dist.init_process_group(backend="nccl", device_id=device) dist.barrier() - with open("final_model.int8.ptz", "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" - ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + if _HAS_FA3: + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + else: + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + + torch._dynamo.config.optimize_ddp = False + h = Hyperparameters() + set_logging_hparams(h) + + if h.is_main_process: + os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:") + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}") + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log(subprocess.run( + ["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False + ).stdout, console=False) + log("=" * 100, console=False) + + train_and_eval(h, device) if distributed: dist.destroy_process_group() diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py deleted file mode 100644 index 7b9e935aa6..0000000000 --- a/train_gpt_mlx.py +++ /dev/null @@ -1,1104 +0,0 @@ -#!/usr/bin/env python3 -""" -The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. - -Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. -""" -from __future__ import annotations - -import glob -import json -import math -import os -import pickle -import sys -import time -import uuid -import zlib -from collections.abc import Callable -from pathlib import Path - -import numpy as np -import sentencepiece as spm - -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as optim -from mlx.utils import tree_flatten, tree_unflatten - -# ============================================================================== -# SHARD FORMAT + COMPUTE DTYPE -# ============================================================================== - -COMPUTE_DTYPE = mx.bfloat16 - -# ============================================================================== -# HYPERPARAMETERS -# ============================================================================== -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap -class Hyperparameters: - # Data / tokenizer. - data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") - tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") - run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed: int = int(os.environ.get("SEED", 1337)) - - # Training loop. These defaults now mirror train_gpt.py on a single process. - iterations: int = int(os.environ.get("ITERATIONS", 20_000)) - val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) - # Validation always uses the full fineweb_val split. - val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) - train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) - train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) - # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak - # memory pressure without changing the effective optimizer batch. - mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) - # Force MLX to materialize the graph after every sub-batch, preventing lazy - # graph buildup across accumulation steps. Keeps peak memory low on 16GB machines. - # Disable on 32GB+ unified memory for better throughput (MLX_EAGER_EVAL=0). - mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) - warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) - warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) - max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - - # Model (defaults match the current baseline setup). - vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) - model_dim: int = int(os.environ.get("MODEL_DIM", 512)) - num_heads: int = int(os.environ.get("NUM_HEADS", 8)) - num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) - mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) - tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) - tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) - logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) - qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Optimizer. We keep the same per-group defaults as train_gpt.py. - beta1: float = float(os.environ.get("BETA1", 0.9)) - beta2: float = float(os.environ.get("BETA2", 0.95)) - adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) - tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) - matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) - grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) - - out_dir: str = os.environ.get("OUT_DIR", "logs") - - @property - def train_files(self) -> str: - return f"{self.data_path}/fineweb_train_*.bin" - - @property - def val_files(self) -> str: - return f"{self.data_path}/fineweb_val_*.bin" - - @property - def microbatch_tokens(self) -> int: - return self.train_batch_tokens // self.grad_accum_steps - - def lr_mul(self, step: int, elapsed_ms: float) -> float: - if self.warmdown_iters <= 0: - return 1.0 - if self.max_wallclock_seconds <= 0: - warmdown_start = max(self.iterations - self.warmdown_iters, 0) - return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 - step_ms = elapsed_ms / max(step, 1) - warmdown_ms = self.warmdown_iters * step_ms - remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) - return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", - ).split(",") - if pattern -) -INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) - - -def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: - usable_total = (total_tokens // seq_len) * seq_len - if usable_total <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) - chunks: list[int] = [] - remaining = usable_total - while remaining > 0: - chunk = min(remaining, usable_chunk) - chunks.append(chunk) - remaining -= chunk - return chunks - - -def accumulate_flat_grads( - accum: dict[str, mx.array] | None, - grads_tree: dict, - scale: float, -) -> dict[str, mx.array]: - flat = dict(tree_flatten(grads_tree)) - if accum is None: - return {k: g * scale for k, g in flat.items()} - for k, g in flat.items(): - accum[k] = accum[k] + g * scale - return accum - - -# ============================================================================== -# MATH HELPERS -# ============================================================================== - -def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: - return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) - - -def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. - # Background on Muon: https://kellerjordan.github.io/posts/muon/ - a, b, c = 3.4445, -4.7750, 2.0315 - x = g.astype(mx.float32) - x = x / (mx.sqrt(mx.sum(x * x)) + eps) - transposed = x.shape[0] > x.shape[1] - if transposed: - x = x.T - for _ in range(steps): - a_mat = x @ x.T - b_mat = b * a_mat + c * (a_mat @ a_mat) - x = a * x + b_mat @ x - if transposed: - x = x.T - return x.astype(g.dtype) - - -def load_data_shard(path: Path) -> np.ndarray: - header_bytes = 256 * np.dtype(" None: - self.file_idx = (self.file_idx + 1) % len(self.files) - if self.file_idx == 0: - self.epoch += 1 - if self.log_fn is not None: - self.log_fn( - f"WARNING: starting epoch:{self.epoch} " - f"dataset:{self.dataset_name} train_shards:{len(self.files)}" - ) - self.tokens = load_data_shard(self.files[self.file_idx]) - self.pos = 0 - - def take(self, n: int) -> np.ndarray: - chunks: list[np.ndarray] = [] - left = n - while left > 0: - if self.pos >= self.tokens.size: - self.next_file() - k = min(left, int(self.tokens.size - self.pos)) - chunks.append(self.tokens[self.pos : self.pos + k]) - self.pos += k - left -= k - return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) - - -class TokenLoader: - def __init__( - self, - pattern: str, - log_fn: Callable[[str], None] | None = None, - dataset_name: str = "", - ): - self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) - - def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: - usable = (batch_tokens // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - chunk = self.stream.take(usable + 1) - x = chunk[:-1].reshape(-1, seq_len) - y = chunk[1:].reshape(-1, seq_len) - return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) - - -# ============================================================================== -# MODEL BLOCKS -# ============================================================================== - -class CastedLinear(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) - - def __call__(self, x: mx.array) -> mx.array: - return x @ self.weight.astype(x.dtype).T - - -class RMSNormNoWeight(nn.Module): - # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. - def __call__(self, x: mx.array) -> mx.array: - return rms_norm(x) - - -class CausalSelfAttention(nn.Module): - # - separate q/k/v projections - # - RMSNorm on q and k before attention - # - RoPE on q and k - # - causal masked SDPA - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim - self.c_q = CastedLinear(dim, dim) - self.c_k = CastedLinear(dim, kv_dim) - self.c_v = CastedLinear(dim, kv_dim) - self.proj = CastedLinear(dim, dim) - self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init - self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) - self.scale = self.head_dim ** -0.5 - - def __call__(self, x: mx.array) -> mx.array: - bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - - q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) - k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) - q = q * self.q_gain.astype(q.dtype)[None, :, None, None] - y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") - y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) - return self.proj(y) - - -class MLP(nn.Module): - # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. - def __init__(self, dim: int, mlp_mult: int): - super().__init__() - hidden = dim * mlp_mult - self.fc = CastedLinear(dim, hidden) - self.proj = CastedLinear(hidden, dim) - - def __call__(self, x: mx.array) -> mx.array: - x = nn.relu(self.fc(x)) - return self.proj(x * x) - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): - super().__init__() - self.attn_norm = RMSNormNoWeight() - self.mlp_norm = RMSNormNoWeight() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.mlp = MLP(dim, mlp_mult) - self.attn_scale = mx.ones((dim,), dtype=mx.float32) - self.mlp_scale = mx.ones((dim,), dtype=mx.float32) - self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) - - def __call__(self, x: mx.array, x0: mx.array) -> mx.array: - mix = self.resid_mix.astype(x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out - x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) - return x - - -class GPT(nn.Module): - # - token embedding + RMSNorm - # - encoder half accumulates skip tensors - # - decoder half consumes reversed skips with learned skip_weights - # - tied embeddings for the LM head (the baseline default setup) - def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, - logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, - qk_gain_init: float): - super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") - self.logit_chunk_tokens = logit_chunk_tokens - self.logit_softcap = logit_softcap - - self.tok_emb = nn.Embedding(vocab_size, dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) - self.blocks = [ - Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for i in range(num_layers) - ] - self.final_norm = RMSNormNoWeight() - - for b in self.blocks: - b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) - b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) - self.tok_emb.weight = ( - mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std - ).astype(COMPUTE_DTYPE) - - def softcap(self, logits: mx.array) -> mx.array: - c = self.logit_softcap - return c * mx.tanh(logits / c) - - def __call__(self, input_ids: mx.array) -> mx.array: - x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) - x0 = x - skips: list[mx.array] = [] - - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - # Odd layer counts have one more decoder block than encoder block. The baseline only - # applies a skip connection when one exists, then runs the remaining decoder block(s) - # without an added skip. - if skips: - x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - return self.final_norm(x) - - def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: - # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful - # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). - x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) - y = target_ids.reshape(-1) - if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: - logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") - - loss_sum = mx.array(0.0, dtype=mx.float32) - n = int(x.shape[0]) - for s in range(0, n, self.logit_chunk_tokens): - e = min(s + self.logit_chunk_tokens, n) - logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T - logits = self.softcap(logits_proj) - loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") - return loss_sum / float(n) - -# ============================================================================== -# OPTIMIZERS (MUON + ADAM SPLIT) -# ============================================================================== -class Muon: - # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the - # parameter update. - def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): - self.keys = keys - self.args = args - self.buffers = {k: mx.zeros_like(params[k]) for k in keys} - - def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: - if self.args.muon_momentum_warmup_steps: - t = min(step / self.args.muon_momentum_warmup_steps, 1.0) - momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum - else: - momentum = self.args.muon_momentum - lr = self.args.matrix_lr * lr_mul - out: dict[str, mx.array] = {} - for k in self.keys: - p = params[k] - g = grads[k] - buf = momentum * self.buffers[k] + g - self.buffers[k] = buf - g_eff = g + momentum * buf - g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) - scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) - out[k] = p - lr * (g_ortho * scale).astype(p.dtype) - return out - - -class SplitOptimizers: - # - embeddings: Adam with the tied-embedding LR - # - block matrices (2D): Muon - # - block scalars + skip weights: Adam - # This preserves the high-level optimization behavior even though MLX internals differ. - def __init__(self, model: GPT, args: Hyperparameters): - self.args = args - params = dict(tree_flatten(model.parameters())) - self.embed_key = "tok_emb.weight" - self.matrix_keys = [ - k - for k, p in params.items() - if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) - ] - self.scalar_keys = [ - k - for k, p in params.items() - if k == "skip_weights" or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) - ] - - self.muon = Muon(self.matrix_keys, params, args) - self.adam_embed = optim.Adam( - learning_rate=args.tied_embed_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - self.adam_scalar = optim.Adam( - learning_rate=args.scalar_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - - def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: - params = dict(tree_flatten(model.parameters())) - grads = dict(tree_flatten(grads_tree)) - updated = dict(params) - - updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) - - self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul - updated.update( - self.adam_embed.apply_gradients( - {self.embed_key: grads[self.embed_key]}, - {self.embed_key: params[self.embed_key]}, - ) - ) - - self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul - scalar_grads = {k: grads[k] for k in self.scalar_keys} - scalar_params = {k: params[k] for k in self.scalar_keys} - updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) - - model.update(tree_unflatten(list(updated.items()))) - -# ============================================================================== -# QUANTIZATION (INT8 + ZLIB) -# ============================================================================== -# - per-row int8 for 2D float tensors -# - per-tensor int8 for other float tensors -# - fp16 passthrough for small float tensors -# - exact passthrough for non-floats - -MX_DTYPE_FROM_NAME = { - "float32": mx.float32, - "float16": mx.float16, - "bfloat16": mx.bfloat16, -} - -INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 -INT8_PER_ROW_SCALE_DTYPE = np.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - - -def _np_float32(arr: mx.array) -> np.ndarray: - return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) - - -def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return np.ascontiguousarray(_np_float32(arr)) - if arr.dtype in {mx.float32, mx.bfloat16}: - passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] - return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) - return np.ascontiguousarray(np.array(arr, copy=True)) - - -def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: - f32 = _np_float32(arr) - if f32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) - clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) - scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) - q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) - - # Vectors / scalars use a simpler per-tensor scale. - clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 - scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) - q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) - return np.ascontiguousarray(q), scale - - -def quantize_state_dict_int8(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: - quantized: dict[str, np.ndarray] = {} - scales: dict[str, np.ndarray] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, np.ndarray] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) - for name, arr in flat_state.items(): - stats["param_count"] += int(arr.size) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += int(arr.nbytes) - if not mx.issubdtype(arr.dtype, mx.floating): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = np.ascontiguousarray(np.array(arr)) - stats["int8_payload_bytes"] += int(passthrough[name].nbytes) - continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. - if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_array(name, arr, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += int(kept.nbytes) - continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_array(arr) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(arr.dtype).split(".")[-1] - stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - - -def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: - out: dict[str, mx.array] = {} - qmeta = quant_obj.get("qmeta", {}) - passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) - for name, q in quant_obj["quantized"].items(): - q_np = np.asarray(q, dtype=np.int8) - dtype_name = quant_obj["dtypes"][name] - scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) - if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: - # Broadcast the saved row scale back across trailing dimensions. - out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) - else: - out_arr = q_np.astype(np.float32) * float(scale) - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) - for name, arr in quant_obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_arr = np.array(arr, copy=True) - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) - else: - out[name] = mx.array(out_arr) - return out - - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - sp_vocab_size = int(sp.vocab_size()) - table_size = max(sp_vocab_size, vocab_size) - base_bytes_lut = np.zeros((table_size,), dtype=np.int16) - has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) - is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): - continue - is_boundary_token_lut[token_id] = False - if sp.is_byte(token_id): - base_bytes_lut[token_id] = 1 - continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_lut[token_id] = True - piece = piece[1:] - base_bytes_lut[token_id] = len(piece.encode("utf-8")) - return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut - - -def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: - # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we - # decode bytes with the exact tokenizer that produced the shards. The manifest - # lets the training script fail fast on accidental dataset/tokenizer mismatches. - dataset_dir = Path(data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - if len(dataset_dir.parents) < 2: - return dataset_dir.name, actual_train_files, None - manifest_path = dataset_dir.parents[1] / "manifest.json" - if not manifest_path.is_file(): - return dataset_dir.name, actual_train_files, None - - manifest = json.loads(manifest_path.read_text(encoding="utf-8")) - dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) - if dataset_entry is None: - return dataset_dir.name, actual_train_files, None - - tokenizer_name = dataset_entry.get("tokenizer_name") - tokenizer_entry = ( - next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) - if tokenizer_name - else None - ) - expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name - if expected_name and Path(tokenizer_path).name != expected_name: - raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") - expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") - if expected_train_files is not None: - expected_train_files = int(expected_train_files) - if actual_train_files > expected_train_files: - raise ValueError( - f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " - f"manifest says {expected_train_files}" - ) - return dataset_dir.name, actual_train_files, expected_train_files - - -def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: - files = [Path(p) for p in sorted(glob.glob(pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. - tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) - usable = ((tokens.size - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] - - -def loss_and_grad_chunked( - args: Hyperparameters, - train_loader: TokenLoader, - compiled_loss_and_grad, -) -> tuple[mx.array, dict]: - chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) - total_tokens = float(sum(chunk_sizes)) - loss_value = mx.array(0.0, dtype=mx.float32) - grad_accum: dict[str, mx.array] | None = None - for chunk_tokens in chunk_sizes: - x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) - loss, grads = compiled_loss_and_grad(x, y) - scale = float(y.size) / total_tokens - loss_value = loss_value + loss.astype(mx.float32) * scale - grad_accum = accumulate_flat_grads(grad_accum, grads, scale) - if args.mlx_eager_eval: - mx.eval(loss_value, grad_accum) # materialize each chunk to cap peak memory - return loss_value, tree_unflatten(list(grad_accum.items())) - - -def eval_val( - args: Hyperparameters, - compiled_loss, - val_tokens: np.ndarray, - base_bytes_lut: np.ndarray, - has_leading_space_lut: np.ndarray, - is_boundary_token_lut: np.ndarray, - log_fn: Callable[[str], None] | None = None, -) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - val_batch_seqs = val_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.size - 1) // args.train_seq_len - total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) - total_loss_sum = 0.0 - total_tokens = 0.0 - total_bytes = 0.0 - for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): - batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - chunk = val_tokens[raw_start:raw_end] - x_np = chunk[:-1].reshape(-1, args.train_seq_len) - y_np = chunk[1:].reshape(-1, args.train_seq_len) - x = mx.array(x_np, dtype=mx.int32) - y = mx.array(y_np, dtype=mx.int32) - chunk_token_count = float(y.size) - batch_loss = compiled_loss(x, y).astype(mx.float32) - mx.eval(batch_loss) - total_loss_sum += float(batch_loss.item()) * chunk_token_count - prev_ids = x_np.reshape(-1) - tgt_ids = y_np.reshape(-1) - bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) - bytes_np += ( - has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] - ).astype(np.int16, copy=False) - total_tokens += chunk_token_count - total_bytes += float(bytes_np.astype(np.float64).sum()) - if log_fn is not None and total_batches > 1 and ( - batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 - ): - log_fn(f"val_progress:{batch_idx}/{total_batches}") - val_loss = total_loss_sum / total_tokens - bits_per_token = val_loss / math.log(2.0) - val_bpb = bits_per_token * (total_tokens / total_bytes) - return val_loss, val_bpb - -# ----------------------------- -# TRAINING -# ----------------------------- - -def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: - if max_norm <= 0: - return grads_tree - flat = dict(tree_flatten(grads_tree)) - total_sq = 0.0 - for grad in flat.values(): - total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) - if total_sq <= 0.0: - return grads_tree - total_norm = math.sqrt(total_sq) - if total_norm <= max_norm: - return grads_tree - scale = max_norm / (total_norm + 1e-12) - return tree_unflatten([(k, g * scale) for k, g in flat.items()]) - - -def main() -> None: - # ============================================================================== - # TOKENIZER + VALIDATION METRIC SETUP - # ============================================================================== - args = Hyperparameters() - out_dir = Path(args.out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - logfile = out_dir / f"{args.run_id}.txt" - print(logfile) - - def log(msg: str, console: bool = True) -> None: - if console: - print(msg) - with logfile.open("a", encoding="utf-8") as f: - print(msg, file=f) - - code = Path(__file__).read_text(encoding="utf-8") - log(code, console=False) - log("=" * 100, console=False) - log(f"Running Python {sys.version}", console=False) - log(f"Running MLX {mx.__version__}", console=False) - log("=" * 100, console=False) - - if not args.tie_embeddings: - raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") - if not args.tokenizer_path.endswith(".model"): - raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) - if int(sp.vocab_size()) != args.vocab_size: - raise ValueError( - f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" - ) - dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( - args.data_path, - args.tokenizer_path, - ) - val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) - - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( - sp, args.vocab_size - ) - - # ============================================================================== - # TRAINING SETUP - # ============================================================================== - mx.random.seed(args.seed) - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - # ============================================================================== - # MODEL + OPTIMIZER SETUP - # ============================================================================== - model = GPT( - vocab_size=args.vocab_size, - num_layers=args.num_layers, - dim=args.model_dim, - num_heads=args.num_heads, - num_kv_heads=args.num_kv_heads, - mlp_mult=args.mlp_mult, - logit_chunk_tokens=args.logit_chunk_tokens, - logit_softcap=args.logit_softcap, - rope_base=args.rope_base, - tied_embed_init_std=args.tied_embed_init_std, - qk_gain_init=args.qk_gain_init, - ) - opt = SplitOptimizers(model, args) - - # ============================================================================== - # COMPILED TRAIN / EVAL FUNCTIONS (MLX) - # ============================================================================== - # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example - # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". - # Compiling the model-bound functions and capturing the full model state fixes that while still - # returning gradients only for trainable parameters via nn.value_and_grad(...). - compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, - outputs=model.state, - ) - - # Print config once so logs are self-describing. - n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) - log(f"run_id:{args.run_id}") - log(f"mlx_version:{mx.__version__}") - log(f"train_loader:shards pattern={args.train_files}") - log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") - if expected_train_files is None: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") - elif actual_train_files < expected_train_files: - log( - f"WARNING: train_loader:subset dataset:{dataset_name} " - f"train_shards:{actual_train_files}/{expected_train_files} " - f"new epochs will arrive sooner than the full dataset" - ) - else: - log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") - log(f"tokenizer_path:{args.tokenizer_path}") - log( - f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " - f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " - f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" - ) - log( - f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " - f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " - f"val_batch_size:{args.val_batch_size} " - f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" - ) - log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") - log( - f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " - f"embed_lr:{args.tied_embed_lr} " - f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " - f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" - ) - log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") - log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") - log( - f"dtypes tok_emb:{model.tok_emb.weight.dtype} " - f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " - f"skip_weights:{model.skip_weights.dtype}" - ) - - # ============================================================================== - # TRAINING LOOP - # ============================================================================== - if args.warmup_steps > 0: - # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us - # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. - # Instead we run the real train shapes, force the loss/grads to materialize, and then reset - # the loader so measured training still starts from the true init and token window. - for warmup_step in range(args.warmup_steps): - accum: dict[str, mx.array] | None = None - warmup_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - mx.eval(warmup_loss, accum) - mx.synchronize() - if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: - log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") - - # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. - val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) - warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) - warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1] - x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32) - y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32) - warm_val_loss = compiled_loss(x_val, y_val) - mx.eval(warm_val_loss) - mx.synchronize() - - train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - train_time_ms = 0.0 - max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None - stop_after_step: int | None = None - t0 = time.perf_counter() - step = 0 - while True: - last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - train_time_ms += 1000.0 * (time.perf_counter() - t0) - # Validation always scans the same fixed full validation split. - val_loss, val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - if step % 25 == 0 or last_step: - log( - f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " - f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" - ) - t0 = time.perf_counter() - if last_step: - if stop_after_step is not None and step < args.iterations: - log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") - break - - lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) - step_t0 = time.perf_counter() - - accum: dict[str, mx.array] | None = None - train_loss = mx.array(0.0, dtype=mx.float32) - grad_scale = 1.0 / args.grad_accum_steps - for _ in range(args.grad_accum_steps): - loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) - accum = accumulate_flat_grads(accum, grads, grad_scale) - train_loss = train_loss + loss.astype(mx.float32) * grad_scale - if args.mlx_eager_eval: - mx.eval(train_loss, accum) # materialize each microbatch to cap peak memory - - grads = tree_unflatten(list(accum.items())) - grads = clip_grad_tree(grads, args.grad_clip_norm) - train_loss_value = float(train_loss.item()) - opt.step(model, grads, step=step, lr_mul=lr_mul) - mx.synchronize() - - step_ms = 1000.0 * (time.perf_counter() - step_t0) - approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) - tok_s = args.train_batch_tokens / (step_ms / 1000.0) - step += 1 - if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): - log( - f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} " - f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}" - ) - if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms: - stop_after_step = step - - # ============================================================================== - # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL - # ============================================================================== - # We always write a raw artifact and a quantized artifact, then validate the - # quantized roundtrip directly by loading the dequantized tensors back into the - # model and running one final validation pass. - out_path = out_dir / f"{args.run_id}_mlx_model.npz" - flat_state = {k: v for k, v in tree_flatten(model.state)} - mx.savez(str(out_path), **flat_state) - log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") - - quant_obj, quant_stats = quantize_state_dict_int8(flat_state) - quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) - quant_blob = zlib.compress(quant_raw, level=9) - quant_serialized_bytes = len(quant_raw) - quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" - with quant_path.open("wb") as f: - f.write(quant_blob) - quant_file_bytes = quant_path.stat().st_size - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log( - f"serialized_model_int8_zlib:{quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" - ) - - with quant_path.open("rb") as f: - quant_blob_disk = f.read() - quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) - model.update(tree_unflatten(list(quant_flat.items()))) - q_t0 = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - compiled_loss, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ) - q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) - log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") - log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - - -if __name__ == "__main__": - main() diff --git a/train_round1_baseline.log b/train_round1_baseline.log new file mode 100644 index 0000000000..0955b438bf --- /dev/null +++ b/train_round1_baseline.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_novel_test.py:1137: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: none (baseline PR#1394) +test_mode: medium +model_params:35943512 +0/2000 val_loss: 9.0055 val_bpb: 3.4863 +500/2000 val_loss: 4.4588 val_bpb: 1.7261 +1000/2000 val_loss: 4.2513 val_bpb: 1.6458 +1500/2000 val_loss: 4.1785 val_bpb: 1.6176 +2000/2000 val_loss: 4.2997 val_bpb: 1.6645 +pre-quantization post-ema val_loss:4.26408302 val_bpb:1.65075942 eval_time:93994ms +Serialized quantized+brotli: 15990993 bytes +Total submission: 16053780 bytes +quantized val_loss:4.27937457 val_bpb:1.65667926 eval_time:132661ms diff --git a/train_round1_cat.log b/train_round1_cat.log new file mode 100644 index 0000000000..269cfd8dc2 --- /dev/null +++ b/train_round1_cat.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_novel_test.py:1137: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: CAT(weight=0.001,bits=6,every=50) +test_mode: medium +model_params:35943512 +0/2000 val_loss: 9.0055 val_bpb: 3.4863 +500/2000 val_loss: 4.4583 val_bpb: 1.7260 +1000/2000 val_loss: 4.2454 val_bpb: 1.6435 +1500/2000 val_loss: 4.1537 val_bpb: 1.6080 +2000/2000 val_loss: 3.9925 val_bpb: 1.5456 +pre-quantization post-ema val_loss:3.79210831 val_bpb:1.46804330 eval_time:111492ms +Serialized quantized+brotli: 16006928 bytes +Total submission: 16069715 bytes +quantized val_loss:3.81251559 val_bpb:1.47594359 eval_time:138398ms diff --git a/train_round1_cat_sparsity_13L.log b/train_round1_cat_sparsity_13L.log new file mode 100644 index 0000000000..8a958f6759 --- /dev/null +++ b/train_round1_cat_sparsity_13L.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_novel_test.py:1137: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: CAT(weight=0.001,bits=6,every=50), 2:4_Sparsity(apply_to=['mlp']) +test_mode: medium +model_params:41715816 +0/2000 val_loss: 9.0080 val_bpb: 3.4873 +500/2000 val_loss: 4.4495 val_bpb: 1.7225 +1000/2000 val_loss: 4.2407 val_bpb: 1.6417 +1500/2000 val_loss: 4.1238 val_bpb: 1.5964 +2000/2000 val_loss: 3.9471 val_bpb: 1.5280 +pre-quantization post-ema val_loss:3.75875398 val_bpb:1.45513080 eval_time:108052ms +Serialized quantized+brotli: 16558100 bytes +Total submission: 16620887 bytes +quantized val_loss:3.86542404 val_bpb:1.49642611 eval_time:144798ms diff --git a/train_round1_kan.log b/train_round1_kan.log new file mode 100644 index 0000000000..8b2e5c565d --- /dev/null +++ b/train_round1_kan.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_novel_test.py:1137: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: KAN(grid=5,order=3) +test_mode: medium +model_params:128218200 +0/2000 val_loss: 9.0084 val_bpb: 3.4874 +500/2000 val_loss: 4.6328 val_bpb: 1.7935 +1000/2000 val_loss: 4.3789 val_bpb: 1.6952 +1500/2000 val_loss: 4.2677 val_bpb: 1.6522 +2000/2000 val_loss: 4.1121 val_bpb: 1.5919 +pre-quantization post-ema val_loss:3.93008766 val_bpb:1.52145940 eval_time:153343ms +Serialized quantized+brotli: 54949114 bytes +Total submission: 55011901 bytes +quantized val_loss:3.95771947 val_bpb:1.53215654 eval_time:200448ms diff --git a/train_round1_moe_4e.log b/train_round1_moe_4e.log new file mode 100644 index 0000000000..43be3174a5 --- /dev/null +++ b/train_round1_moe_4e.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_novel_test.py:1137: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: MoE(experts=4,top_k=2) +test_mode: medium +model_params:105172056 +0/2000 val_loss: 9.0255 val_bpb: 3.4940 +500/2000 val_loss: 4.4602 val_bpb: 1.7267 +1000/2000 val_loss: 4.2245 val_bpb: 1.6354 +1500/2000 val_loss: 4.0591 val_bpb: 1.5714 +2000/2000 val_loss: 3.8109 val_bpb: 1.4753 +pre-quantization post-ema val_loss:3.69149008 val_bpb:1.42909085 eval_time:216755ms +Serialized quantized+brotli: 45346783 bytes +Total submission: 45409570 bytes +quantized val_loss:3.71114800 val_bpb:1.43670104 eval_time:250264ms diff --git a/train_round1_sparsity_13L.log b/train_round1_sparsity_13L.log new file mode 100644 index 0000000000..c6393911ee --- /dev/null +++ b/train_round1_sparsity_13L.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_novel_test.py:1137: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: 2:4_Sparsity(apply_to=['mlp']) +test_mode: medium +model_params:41715816 +0/2000 val_loss: 9.0080 val_bpb: 3.4873 +500/2000 val_loss: 4.4478 val_bpb: 1.7219 +1000/2000 val_loss: 4.2452 val_bpb: 1.6434 +1500/2000 val_loss: 4.1385 val_bpb: 1.6021 +2000/2000 val_loss: 3.9734 val_bpb: 1.5382 +pre-quantization post-ema val_loss:3.77538406 val_bpb:1.46156882 eval_time:106539ms +Serialized quantized+brotli: 16561305 bytes +Total submission: 16624092 bytes +quantized val_loss:3.88504955 val_bpb:1.50402375 eval_time:158591ms diff --git a/train_round2_11L_cat.log b/train_round2_11L_cat.log new file mode 100644 index 0000000000..234210df55 --- /dev/null +++ b/train_round2_11L_cat.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_final.py:1097: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: CAT(weight=0.001,bits=6,every=50), ParallelResid(from_layer=7), TTT(lr=0.005,epochs=3,chunk=32768) +test_mode: medium +model_params:35944536 +0/2000 val_loss: 9.0055 val_bpb: 3.2820 +500/2000 val_loss: 4.4646 val_bpb: 1.6271 +1000/2000 val_loss: 4.2908 val_bpb: 1.5638 +1500/2000 val_loss: 4.2245 val_bpb: 1.5396 +2000/2000 val_loss: 4.0978 val_bpb: 1.4934 +pre-quantization post-ema val_loss:3.84097453 val_bpb:1.39983050 eval_time:133471ms +Serialized quantized+brotli: 16011068 bytes +Total submission: 16079171 bytes +quantized val_loss:3.86911694 val_bpb:1.41008691 eval_time:152635ms diff --git a/train_round2_12L_cat_sparse.log b/train_round2_12L_cat_sparse.log new file mode 100644 index 0000000000..d79afe605c --- /dev/null +++ b/train_round2_12L_cat_sparse.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_final.py:1097: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: CAT(weight=0.001,bits=6,every=50), 2:4_Sparsity(apply_to=['mlp']), ParallelResid(from_layer=7), TTT(lr=0.005,epochs=3,chunk=32768) +test_mode: medium +model_params:38831200 +0/2000 val_loss: 9.0064 val_bpb: 3.2824 +500/2000 val_loss: 4.4749 val_bpb: 1.6308 +1000/2000 val_loss: 4.2790 val_bpb: 1.5595 +1500/2000 val_loss: 4.2126 val_bpb: 1.5353 +2000/2000 val_loss: 4.1071 val_bpb: 1.4968 +pre-quantization post-ema val_loss:3.85001668 val_bpb:1.40312588 eval_time:145560ms +Serialized quantized+brotli: 15537409 bytes +Total submission: 15605512 bytes +quantized val_loss:3.95184360 val_bpb:1.44023636 eval_time:171029ms diff --git a/train_round2_12L_wideloop.log b/train_round2_12L_wideloop.log new file mode 100644 index 0000000000..ef1d48bcfb --- /dev/null +++ b/train_round2_12L_wideloop.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_final.py:1097: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: CAT(weight=0.001,bits=6,every=50), 2:4_Sparsity(apply_to=['mlp']), ParallelResid(from_layer=7), TTT(lr=0.005,epochs=3,chunk=32768) +test_mode: medium +model_params:38832224 +0/2000 val_loss: 9.0064 val_bpb: 3.2824 +500/2000 val_loss: 4.4710 val_bpb: 1.6294 +1000/2000 val_loss: 4.2763 val_bpb: 1.5585 +1500/2000 val_loss: 4.2072 val_bpb: 1.5333 +2000/2000 val_loss: 4.0756 val_bpb: 1.4853 +pre-quantization post-ema val_loss:3.82427339 val_bpb:1.39374382 eval_time:158860ms +Serialized quantized+brotli: 15528298 bytes +Total submission: 15596401 bytes +quantized val_loss:3.92083606 val_bpb:1.42893577 eval_time:181695ms diff --git a/train_round2_13L_cat_sparse.log b/train_round2_13L_cat_sparse.log new file mode 100644 index 0000000000..6a286f9d80 --- /dev/null +++ b/train_round2_13L_cat_sparse.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_final.py:1097: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: CAT(weight=0.001,bits=6,every=50), 2:4_Sparsity(apply_to=['mlp']), ParallelResid(from_layer=7), TTT(lr=0.005,epochs=3,chunk=32768) +test_mode: medium +model_params:41716840 +0/2000 val_loss: 9.0080 val_bpb: 3.2829 +500/2000 val_loss: 4.4611 val_bpb: 1.6258 +1000/2000 val_loss: 4.2525 val_bpb: 1.5498 +1500/2000 val_loss: 4.1883 val_bpb: 1.5264 +2000/2000 val_loss: 4.0599 val_bpb: 1.4796 +pre-quantization post-ema val_loss:3.80644901 val_bpb:1.38724778 eval_time:150094ms +Serialized quantized+brotli: 16619390 bytes +Total submission: 16687493 bytes +quantized val_loss:3.91747750 val_bpb:1.42771175 eval_time:172239ms diff --git a/train_round2_baseline_11L.log b/train_round2_baseline_11L.log new file mode 100644 index 0000000000..7909964562 --- /dev/null +++ b/train_round2_baseline_11L.log @@ -0,0 +1,14 @@ +/workspace/parameter-golf/train_gpt_final.py:1097: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: ParallelResid(from_layer=7), TTT(lr=0.005,epochs=3,chunk=32768) +test_mode: medium +model_params:35944536 +0/2000 val_loss: 9.0055 val_bpb: 3.2820 +500/2000 val_loss: 4.4704 val_bpb: 1.6292 +1000/2000 val_loss: 4.2896 val_bpb: 1.5633 +1500/2000 val_loss: 4.2321 val_bpb: 1.5424 +2000/2000 val_loss: 4.1005 val_bpb: 1.4944 +pre-quantization post-ema val_loss:3.84083219 val_bpb:1.39977862 eval_time:135065ms +Serialized quantized+brotli: 16009744 bytes +Total submission: 16077847 bytes +quantized val_loss:3.86772679 val_bpb:1.40958027 eval_time:153099ms diff --git a/train_round3_11L_cat_hsparse_ttt.log b/train_round3_11L_cat_hsparse_ttt.log new file mode 100644 index 0000000000..6ea7bfccf4 --- /dev/null +++ b/train_round3_11L_cat_hsparse_ttt.log @@ -0,0 +1,16 @@ +/workspace/parameter-golf/train_gpt_final.py:1112: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: CAT(weight=0.001,bits=6,every=50), 2:4_Sparsity(apply_to=['mlp']), ParallelResid(from_layer=7), TTT(lr=0.005,epochs=3,chunk=32768) +test_mode: medium +model_params:35944536 +0/2000 val_loss: 9.0055 val_bpb: 3.2820 +500/2000 val_loss: 4.4670 val_bpb: 1.6280 +1000/2000 val_loss: 4.2852 val_bpb: 1.5617 +1500/2000 val_loss: 4.2219 val_bpb: 1.5387 +2000/2000 val_loss: 4.1321 val_bpb: 1.5059 +pre-quantization post-ema val_loss:3.90745741 val_bpb:1.42405996 eval_time:116334ms +Serialized quantized+brotli: 14631625 bytes +Total submission: 14700464 bytes +quantized val_loss:3.99615162 val_bpb:1.45638428 eval_time:139090ms +quantized_sliding_window val_loss:3.95546414 val_bpb:1.44155586 eval_time:814558ms +quantized_ttt val_loss:3.85388371 val_bpb:1.40453521 eval_time:1666809ms diff --git a/train_round3_12L_cat_hsparse_ttt.log b/train_round3_12L_cat_hsparse_ttt.log new file mode 100644 index 0000000000..67c2bed4a3 --- /dev/null +++ b/train_round3_12L_cat_hsparse_ttt.log @@ -0,0 +1,16 @@ +/workspace/parameter-golf/train_gpt_final.py:1112: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: CAT(weight=0.001,bits=6,every=50), 2:4_Sparsity(apply_to=['mlp']), ParallelResid(from_layer=7), TTT(lr=0.005,epochs=3,chunk=32768) +test_mode: medium +model_params:38831200 +0/2000 val_loss: 9.0064 val_bpb: 3.2824 +500/2000 val_loss: 4.4619 val_bpb: 1.6261 +1000/2000 val_loss: 4.2688 val_bpb: 1.5558 +1500/2000 val_loss: 4.2009 val_bpb: 1.5310 +2000/2000 val_loss: 4.1248 val_bpb: 1.5033 +pre-quantization post-ema val_loss:3.93670028 val_bpb:1.43471743 eval_time:118724ms +Serialized quantized+brotli: 15689081 bytes +Total submission: 15757920 bytes +quantized val_loss:4.01513726 val_bpb:1.46330353 eval_time:142837ms +quantized_sliding_window val_loss:3.98232965 val_bpb:1.45134691 eval_time:846131ms +quantized_ttt val_loss:3.87431506 val_bpb:1.41198134 eval_time:1744594ms diff --git a/train_round3_baseline_ttt.log b/train_round3_baseline_ttt.log new file mode 100644 index 0000000000..9bc97ecef0 --- /dev/null +++ b/train_round3_baseline_ttt.log @@ -0,0 +1,16 @@ +/workspace/parameter-golf/train_gpt_final.py:1112: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + qs = torch.load(io.BytesIO(_decompress(qbd, h.compressor)), map_location="cpu") +novel_features: ParallelResid(from_layer=7), TTT(lr=0.005,epochs=3,chunk=32768) +test_mode: medium +model_params:35944536 +0/2000 val_loss: 9.0055 val_bpb: 3.2820 +500/2000 val_loss: 4.4659 val_bpb: 1.6276 +1000/2000 val_loss: 4.2846 val_bpb: 1.5615 +1500/2000 val_loss: 4.1987 val_bpb: 1.5302 +2000/2000 val_loss: 4.0455 val_bpb: 1.4744 +pre-quantization post-ema val_loss:3.80884927 val_bpb:1.38812255 eval_time:111323ms +Serialized quantized+brotli: 16007649 bytes +Total submission: 16076488 bytes +quantized val_loss:3.83338595 val_bpb:1.39706486 eval_time:133116ms +quantized_sliding_window val_loss:3.79110274 val_bpb:1.38165489 eval_time:814429ms +quantized_ttt val_loss:3.75805483 val_bpb:1.36961069 eval_time:1667511ms