diff --git a/records/track_non_record_16mb/2026-03-19_BitNet158/README.md b/records/track_non_record_16mb/2026-03-19_BitNet158/README.md new file mode 100644 index 0000000000..117a79e96e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_BitNet158/README.md @@ -0,0 +1,126 @@ +# BitNet b1.58: Ternary Weights Beat Full-Precision via Chinchilla Scaling + +**val_bpb: 1.2029** (post-quantization ternary roundtrip) | **15.11 MB** | 8×H100 SXM, 10 minutes + +## Abstract + +This submission presents a BitNet b1.58 model that achieves **1.2029 val_bpb** using ternary {-1, 0, 1} weights — beating the naive baseline (1.2244) and the 4-hour unlimited-compute baseline (1.2074) in just 10 minutes of training. The key insight is that under a fixed artifact size constraint, Chinchilla scaling laws favor *more parameters at lower precision* over fewer parameters at higher precision. The 64.5M-parameter ternary model fits in 15.1MB while the baseline's 17.1M model requires 15.9MB at INT8. No eval tricks, no sliding window — just scaling laws and ternary weights. + +## 1. Motivation: Chinchilla Under a Size Constraint + +The Chinchilla scaling law states that optimal performance requires balancing model size (N) with training tokens (T), roughly T ≈ 20N. The baseline's 17.1M-parameter model sees ~7.2B tokens in 10 minutes — a T/N ratio of ~424×, massively over-trained. The model improvement rate decelerates significantly because the model approaches its capacity limit long before the wallclock runs out. + +This creates a clear opportunity: **if you can fit more parameters into 16MB, you get a model that uses the full training budget more efficiently.** But models trained in fp16 and stored at INT8 are limited to ~17-20M parameters at 16MB. INT6 pushes to ~22M. INT4 fits ~30M but the quantization gap destroys the gains. + +## 2. Hypothesis: BitNet as Extreme Compression + +The [BitNet b1.58 paper](https://arxiv.org/abs/2402.17764) demonstrates that ternary-weight transformers can match full-precision models at sufficient scale. I hypothesized that: + +1. **More parameters > higher precision** under a fixed size budget. At 1.6 bits/param (base-3 packed), I fit **64.5M parameters** in 15.1MB — 3.8× more than the baseline. + +2. **BitNet models saturate later** because they have more parameters to saturate. The baseline exhausts its 17M parameters within ~4B tokens. The 64.5M model continues learning through the full ~2.5B token budget. + +3. **Zero quantization gap.** Unlike fp16 models that suffer 0.005-0.05 BPB degradation from post-training quantization, this model trains with ternary quantization active in every forward pass via Straight-Through Estimation (STE). The weights are already {-1, 0, 1} during training — what you train is what you ship. + +## 3. Method + +### 3.1 Architecture + +| Component | Configuration | +|-----------|--------------| +| Layers | 12 | +| Model dim | 768 | +| Attention heads | 12 (6 KV heads, GQA) | +| MLP expansion | 3× (hidden dim 2304) | +| Sequence length | 2048 | +| Vocabulary | 1024 (SentencePiece BPE) | +| Embeddings | fp16, tied input/output | +| Total parameters | 64.5M ternary + ~0.8M fp16 (embedding + scalars). Artifact also stores ~1M fp16 group scales. | + +All linear layers in attention (Q, K, V, O) and MLP (up, down) use **BitLinear**: ternary weight quantization with per-group (g=64) mean-absolute scaling, RMSNorm on inputs, and STE gradients. 
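For reference, a condensed sketch of the BitLinear layer (the full version, with caching for artifact serialization, is in the included `train_gpt.py`):

```python
import torch
import torch.nn.functional as F

class BitLinear(torch.nn.Linear):
    """Ternary-weight linear layer (condensed from train_gpt.py)."""
    def __init__(self, in_features: int, out_features: int, group_size: int = 64):
        super().__init__(in_features, out_features, bias=False)
        self.group_size = group_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.rms_norm(x, (x.size(-1),))                    # BitNet b1.58-style input norm
        w = self.weight.reshape(-1, self.group_size)
        # The fp16 roundtrip on the scale matches exactly what the artifact stores (see 3.2).
        scale = w.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float()
        q = (w / scale).round().clamp(-1, 1)                # ternary {-1, 0, 1}
        w_q = (q * scale).reshape(self.weight.shape)
        w_ste = (w_q - self.weight).detach() + self.weight  # straight-through estimator
        return F.linear(x, w_ste.to(x.dtype))
```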
The architecture otherwise matches the baseline: U-Net skip connections, RoPE (base=200,000), logit softcap (30.0). + +### 3.2 Training with fp16 Scale Simulation + +A critical detail: during training, the per-group scales are computed as `scale = w.abs().mean().half().float()` — the `.half().float()` simulates fp16 precision. This ensures the model adapts to the exact scale values that will be stored in the artifact, eliminating the quantization roundtrip gap. + +### 3.3 Ternary Packing + +Weights are packed using **base-3 encoding**: 5 trits per byte (3⁵ = 243 < 256), achieving 1.6 bits per weight — lossless and near the theoretical minimum of log₂(3) ≈ 1.585 bits. Per-group scales are stored in fp16. The artifact is compressed with LZMA. + +### 3.4 Roundtrip Evaluation + +During roundtrip evaluation, the packed ternary values are unpacked and multiplied by their fp16 scales to reconstruct the weight matrices. The model runs inference with these reconstructed weights directly — no re-quantization occurs. This ensures the eval weights are identical to what the model saw during training. + +### 3.5 Training Configuration + +| Parameter | Value | +|-----------|-------| +| Hardware | 8×H100 SXM | +| Wallclock | 600 seconds | +| Optimizer | Muon (matrix) + Adam (scalars/embedding) | +| Matrix/Scalar LR | 0.04 | +| Tied embedding LR | 0.03 | +| Muon momentum | 0.99 (warmup from 0.92 over 1500 steps) | +| LR schedule | Linear warmup (50 steps) + wallclock-aware linear warmdown (last 1200 steps) | +| Batch size | 524,288 tokens/step | +| Total steps | 4,713 | +| Total tokens | ~2.5B | + +No hyperparameter sweep was performed — LR was set to 0.04 from the intuition that STE noise benefits from higher learning rates. Muon momentum (0.99), RoPE base (200,000), and sequence length (2048) were adjusted for the larger model. The warmdown schedule is identical to the baseline. + +## 4. Results + +### 4.1 Key Metrics + +| Metric | Value | +|--------|-------| +| val_bpb (pre-quant) | 1.2013 | +| **val_bpb (post-roundtrip)** | **1.2029** | +| Quantization gap | 0.002 | +| Artifact size | 15,111,456 bytes (15.11 MB) | +| Training steps | 4,713 | +| Wallclock | 600s | +| Step avg | 127.3ms | + +### 4.2 Comparison + +| Model | Params | Bits/param | Artifact | val_bpb | Quant gap | Training | +|-------|--------|-----------|----------|---------|-----------|----------| +| Current SOTA (INT6+SW) | ~20M | 6 | ~15.4MB | 1.1748 | ~0.01 | 10 min | +| Naive Baseline (INT8) | 17.1M | 8 | 15.9MB | 1.2244 | 0.007 | 10 min | +| 4-Hour Baseline (INT8) | 17.1M | 8 | 15.9MB | 1.2074 | 0.033 | 4 hours | +| **BitNet b1.58 (ours)** | **64.5M** | **1.6** | **15.1MB** | **1.2029** | **0.002** | **10 min** | + +This 10-minute ternary model outperforms the 4-hour full-precision baseline. The training efficiency advantage comes from Chinchilla-optimal scaling: 64.5M parameters trained on 2.5B tokens (T/N ≈ 39) vs 17.1M parameters trained on ~173B tokens (T/N ≈ 10,100) in the 4-hour run. + +### 4.3 Scaling Dynamics + +![Scaling Laws](scaling_laws.png) + +The plot shows validation BPB vs training progress for the fp16 baseline (blue) and the BitNet model (orange). The baseline converges quickly but decelerates significantly around 60% of training. The BitNet model starts slower — ternary weights have less capacity per parameter — but continues improving throughout, crossing the baseline at ~80% progress. 
The linear warmdown in the final ~25% provides a large improvement (~0.055 BPB), as the high learning rate (0.04) leaves significant room for the model to settle into a sharper minimum. + +## 5. Key Findings + +1. **Chinchilla scaling holds for ternary models.** Under a fixed artifact size, fitting 3.8× more parameters at 1.6 bits/param outperforms fewer parameters at 8 bits/param, even though each ternary parameter carries less information. + +2. **Near-zero quantization gap is achievable.** By simulating fp16 scale precision during training (`.half().float()`), the model adapts to the exact values stored in the artifact. The roundtrip gap is 0.002 BPB — effectively zero. + +3. **BitNet models plateau later.** The 64.5M model continues improving through the full 10-minute budget, while the 17.1M baseline decelerates early. This validates the Chinchilla argument: the baseline is over-trained, not under-sized. + +4. **Minimal hyperparameter search.** LR, momentum, RoPE base, and sequence length were set from first principles and prior work — no sweep was performed. This suggests the approach is robust and likely improvable with tuning. + +## 6. Limitations & Future Work + +- **No eval tricks.** Adding sliding window evaluation or longer eval sequence lengths would likely improve the score by ~0.03 BPB, as demonstrated by other submissions. +- **No LR tuning.** A sweep over learning rates could improve convergence. +- **Larger models.** With better packing or mixed-precision (ternary MLP + INT4 attention), even more parameters could fit in 16MB. +- **Single run.** Only one seed was evaluated. Multiple seeds would provide variance estimates. + +## Included Files + +- `train_gpt.py` — standalone training script +- `run_8xh100.sh` — exact command used for the submission run +- `train.log` — full training log from the submission run +- `submission.json` — leaderboard metadata +- `scaling_laws.png` — validation BPB comparison vs baseline +- `README.md` — this file diff --git a/records/track_non_record_16mb/2026-03-19_BitNet158/run_8xh100.sh b/records/track_non_record_16mb/2026-03-19_BitNet158/run_8xh100.sh new file mode 100755 index 0000000000..1b5f3cd906 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_BitNet158/run_8xh100.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e +# 12x768 BitNet b1.58 MLP3x on 8xH100 — official submission run +# Run from repo root: bash records/track_non_record_16mb/2026-03-19_BitNet158/run_8xh100.sh +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +RUN_ID=bitnet_12x768_mlp3x_8xh100 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +ITERATIONS=20000 \ +NUM_LAYERS=12 \ +MODEL_DIM=768 \ +NUM_HEADS=12 \ +NUM_KV_HEADS=6 \ +MLP_MULT=3 \ +TRAIN_BATCH_TOKENS=524288 \ +TRAIN_SEQ_LEN=2048 \ +VAL_LOSS_EVERY=500 \ +VAL_BATCH_SIZE=524288 \ +MAX_WALLCLOCK_SECONDS=600 \ +TRAIN_LOG_EVERY=50 \ +LR_WARMUP_STEPS=50 \ +MATRIX_LR=0.04 \ +SCALAR_LR=0.04 \ +TIED_EMBED_LR=0.03 \ +MUON_MOMENTUM=0.99 \ +MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 \ +ROPE_BASE=200000 \ +torchrun --standalone --nproc_per_node=8 "$SCRIPT_DIR/train_gpt.py" diff --git a/records/track_non_record_16mb/2026-03-19_BitNet158/scaling_laws.png b/records/track_non_record_16mb/2026-03-19_BitNet158/scaling_laws.png new file mode 100644 index 0000000000..22f4ba9ac1 Binary files /dev/null and b/records/track_non_record_16mb/2026-03-19_BitNet158/scaling_laws.png differ diff --git a/records/track_non_record_16mb/2026-03-19_BitNet158/submission.json 
b/records/track_non_record_16mb/2026-03-19_BitNet158/submission.json new file mode 100644 index 0000000000..23cd59e550 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_BitNet158/submission.json @@ -0,0 +1,18 @@ +{ + "author": "Etai Zilberman", + "github_id": "ksang123", + "name": "BitNet b1.58: 12x768 MLP3x Ternary + Chinchilla Scaling", + "blurb": "64.5M ternary-weight transformer (12 layers, 768 dim, MLP 3x) packed at 1.6 bits/param via base-3 encoding. Trains with STE ternary quantization and fp16 scale simulation for near-zero roundtrip gap (0.002 BPB). Beats 4-hour baseline in 10 minutes via Chinchilla-optimal scaling: 3.8x more params in same artifact size. No hyperparameter sweep, no eval tricks.", + "track": "non_record_16mb", + "date": "2026-03-19", + "val_loss": 2.03099029, + "val_bpb": 1.20286685, + "pre_quant_val_loss": 2.0284, + "pre_quant_val_bpb": 1.2013, + "step_stop": 4713, + "wallclock_seconds": 600.041, + "eval_time_seconds": 64.36, + "bytes_total": 15111456, + "bytes_model": 15057568, + "bytes_code": 53888 +} diff --git a/records/track_non_record_16mb/2026-03-19_BitNet158/train.log b/records/track_non_record_16mb/2026-03-19_BitNet158/train.log new file mode 100644 index 0000000000..a18d2b55f6 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_BitNet158/train.log @@ -0,0 +1,161 @@ +W0319 22:04:08.662000 141314 torch/distributed/run.py:803] +W0319 22:04:08.662000 141314 torch/distributed/run.py:803] ***************************************** +W0319 22:04:08.662000 141314 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0319 22:04:08.662000 141314 torch/distributed/run.py:803] ***************************************** +logs/bitnet_12x768_mlp3x_8xh100.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:64529040 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:12 num_kv_heads:6 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9656 val_bpb:4.1254 train_time:0ms step_avg:0.04ms +step:1/20000 train_loss:6.9655 train_time:93ms step_avg:92.90ms +step:2/20000 train_loss:6.9701 train_time:210ms step_avg:105.24ms +step:3/20000 train_loss:6.8558 train_time:343ms step_avg:114.34ms +step:4/20000 train_loss:6.6123 train_time:479ms step_avg:119.86ms +step:5/20000 train_loss:6.2247 train_time:613ms step_avg:122.64ms +step:6/20000 train_loss:5.9914 train_time:746ms step_avg:124.36ms +step:7/20000 train_loss:5.7134 train_time:878ms step_avg:125.46ms +step:8/20000 train_loss:5.7020 train_time:1013ms step_avg:126.63ms +step:9/20000 train_loss:5.6503 
train_time:1147ms step_avg:127.49ms +step:10/20000 train_loss:5.5641 train_time:1281ms step_avg:128.11ms +step:50/20000 train_loss:4.4387 train_time:6286ms step_avg:125.72ms +step:100/20000 train_loss:3.4780 train_time:12561ms step_avg:125.61ms +step:150/20000 train_loss:3.0365 train_time:18837ms step_avg:125.58ms +step:200/20000 train_loss:2.8215 train_time:25224ms step_avg:126.12ms +step:250/20000 train_loss:2.7101 train_time:31496ms step_avg:125.98ms +step:300/20000 train_loss:2.4252 train_time:37785ms step_avg:125.95ms +step:350/20000 train_loss:2.6244 train_time:44076ms step_avg:125.93ms +step:400/20000 train_loss:2.3059 train_time:50465ms step_avg:126.16ms +step:450/20000 train_loss:2.4578 train_time:56758ms step_avg:126.13ms +step:500/20000 train_loss:2.4510 train_time:63049ms step_avg:126.10ms +step:500/20000 val_loss:2.4438 val_bpb:1.4474 train_time:63117ms step_avg:126.23ms +step:550/20000 train_loss:2.3515 train_time:69347ms step_avg:126.08ms +step:600/20000 train_loss:2.4954 train_time:75750ms step_avg:126.25ms +step:650/20000 train_loss:2.3402 train_time:82039ms step_avg:126.21ms +step:700/20000 train_loss:2.3934 train_time:88335ms step_avg:126.19ms +step:750/20000 train_loss:2.2277 train_time:94637ms step_avg:126.18ms +step:800/20000 train_loss:2.2456 train_time:101039ms step_avg:126.30ms +step:850/20000 train_loss:2.6662 train_time:107345ms step_avg:126.29ms +step:900/20000 train_loss:2.2992 train_time:113651ms step_avg:126.28ms +step:950/20000 train_loss:2.3473 train_time:119957ms step_avg:126.27ms +step:1000/20000 train_loss:2.3376 train_time:126369ms step_avg:126.37ms +step:1000/20000 val_loss:2.2907 val_bpb:1.3567 train_time:126436ms step_avg:126.44ms +step:1050/20000 train_loss:2.4498 train_time:132668ms step_avg:126.35ms +step:1100/20000 train_loss:2.2113 train_time:138964ms step_avg:126.33ms +step:1150/20000 train_loss:2.2090 train_time:145354ms step_avg:126.39ms +step:1200/20000 train_loss:2.3619 train_time:152505ms step_avg:127.09ms +step:1250/20000 train_loss:2.1745 train_time:159407ms step_avg:127.53ms +step:1300/20000 train_loss:2.3343 train_time:165705ms step_avg:127.47ms +step:1350/20000 train_loss:2.2280 train_time:172100ms step_avg:127.48ms +step:1400/20000 train_loss:2.3885 train_time:178405ms step_avg:127.43ms +step:1450/20000 train_loss:2.2097 train_time:184730ms step_avg:127.40ms +step:1500/20000 train_loss:2.2067 train_time:191712ms step_avg:127.81ms +step:1500/20000 val_loss:2.2411 val_bpb:1.3273 train_time:191780ms step_avg:127.85ms +step:1550/20000 train_loss:2.1273 train_time:198116ms step_avg:127.82ms +step:1600/20000 train_loss:2.0603 train_time:204416ms step_avg:127.76ms +step:1650/20000 train_loss:2.2134 train_time:210722ms step_avg:127.71ms +step:1700/20000 train_loss:2.1385 train_time:217036ms step_avg:127.67ms +step:1750/20000 train_loss:2.2145 train_time:223455ms step_avg:127.69ms +step:1800/20000 train_loss:2.1776 train_time:229770ms step_avg:127.65ms +step:1850/20000 train_loss:2.2778 train_time:236081ms step_avg:127.61ms +step:1900/20000 train_loss:2.1639 train_time:242399ms step_avg:127.58ms +step:1950/20000 train_loss:2.1782 train_time:250831ms step_avg:128.63ms +step:2000/20000 train_loss:2.2134 train_time:257124ms step_avg:128.56ms +step:2000/20000 val_loss:2.1982 val_bpb:1.3019 train_time:257189ms step_avg:128.59ms +step:2050/20000 train_loss:2.2133 train_time:263426ms step_avg:128.50ms +step:2100/20000 train_loss:2.2265 train_time:269836ms step_avg:128.49ms +step:2150/20000 train_loss:2.1444 train_time:276144ms step_avg:128.44ms 
+step:2200/20000 train_loss:2.0327 train_time:282444ms step_avg:128.38ms +step:2250/20000 train_loss:2.1252 train_time:288751ms step_avg:128.33ms +step:2300/20000 train_loss:2.3228 train_time:295158ms step_avg:128.33ms +step:2350/20000 train_loss:2.1475 train_time:301462ms step_avg:128.28ms +step:2400/20000 train_loss:2.1570 train_time:307766ms step_avg:128.24ms +step:2450/20000 train_loss:2.1578 train_time:314064ms step_avg:128.19ms +step:2500/20000 train_loss:2.0877 train_time:320471ms step_avg:128.19ms +step:2500/20000 val_loss:2.1611 val_bpb:1.2799 train_time:320537ms step_avg:128.21ms +step:2550/20000 train_loss:2.0834 train_time:326780ms step_avg:128.15ms +step:2600/20000 train_loss:2.3750 train_time:333092ms step_avg:128.11ms +step:2650/20000 train_loss:2.1802 train_time:339391ms step_avg:128.07ms +step:2700/20000 train_loss:2.1005 train_time:345780ms step_avg:128.07ms +step:2750/20000 train_loss:2.3110 train_time:352081ms step_avg:128.03ms +step:2800/20000 train_loss:2.1872 train_time:358376ms step_avg:127.99ms +step:2850/20000 train_loss:2.1358 train_time:364674ms step_avg:127.96ms +step:2900/20000 train_loss:2.1310 train_time:371067ms step_avg:127.95ms +step:2950/20000 train_loss:2.1755 train_time:377363ms step_avg:127.92ms +step:3000/20000 train_loss:2.1738 train_time:383652ms step_avg:127.88ms +step:3000/20000 val_loss:2.1377 val_bpb:1.2661 train_time:383724ms step_avg:127.91ms +step:3050/20000 train_loss:2.1035 train_time:389945ms step_avg:127.85ms +step:3100/20000 train_loss:2.1408 train_time:396328ms step_avg:127.85ms +step:3150/20000 train_loss:2.1166 train_time:402618ms step_avg:127.82ms +step:3200/20000 train_loss:2.1341 train_time:408903ms step_avg:127.78ms +step:3250/20000 train_loss:2.0360 train_time:415292ms step_avg:127.78ms +step:3300/20000 train_loss:2.1804 train_time:421585ms step_avg:127.75ms +step:3350/20000 train_loss:2.0332 train_time:427875ms step_avg:127.72ms +step:3400/20000 train_loss:2.1038 train_time:434162ms step_avg:127.69ms +step:3450/20000 train_loss:2.0322 train_time:440542ms step_avg:127.69ms +step:3500/20000 train_loss:2.1927 train_time:446828ms step_avg:127.67ms +step:3500/20000 val_loss:2.1216 val_bpb:1.2566 train_time:446898ms step_avg:127.69ms +step:3550/20000 train_loss:2.3345 train_time:453126ms step_avg:127.64ms +step:3600/20000 train_loss:2.0442 train_time:459413ms step_avg:127.61ms +step:3650/20000 train_loss:2.1435 train_time:465789ms step_avg:127.61ms +step:3700/20000 train_loss:2.0649 train_time:472082ms step_avg:127.59ms +step:3750/20000 train_loss:2.0749 train_time:478365ms step_avg:127.56ms +step:3800/20000 train_loss:2.1327 train_time:484653ms step_avg:127.54ms +step:3850/20000 train_loss:2.0929 train_time:491047ms step_avg:127.54ms +step:3900/20000 train_loss:1.9060 train_time:497347ms step_avg:127.52ms +step:3950/20000 train_loss:2.0444 train_time:503642ms step_avg:127.50ms +step:4000/20000 train_loss:2.0870 train_time:509934ms step_avg:127.48ms +step:4000/20000 val_loss:2.0785 val_bpb:1.2310 train_time:510000ms step_avg:127.50ms +step:4050/20000 train_loss:2.0055 train_time:516329ms step_avg:127.49ms +step:4100/20000 train_loss:2.0866 train_time:522613ms step_avg:127.47ms +step:4150/20000 train_loss:2.2170 train_time:528899ms step_avg:127.45ms +step:4200/20000 train_loss:2.0628 train_time:535289ms step_avg:127.45ms +step:4250/20000 train_loss:2.0106 train_time:541576ms step_avg:127.43ms +step:4300/20000 train_loss:1.9043 train_time:547876ms step_avg:127.41ms +step:4350/20000 train_loss:2.0912 train_time:554167ms 
step_avg:127.39ms +step:4400/20000 train_loss:1.9780 train_time:560543ms step_avg:127.40ms +step:4450/20000 train_loss:1.9450 train_time:566831ms step_avg:127.38ms +step:4500/20000 train_loss:2.1330 train_time:573122ms step_avg:127.36ms +step:4500/20000 val_loss:2.0306 val_bpb:1.2026 train_time:573186ms step_avg:127.37ms +step:4550/20000 train_loss:1.9162 train_time:579403ms step_avg:127.34ms +step:4600/20000 train_loss:1.8466 train_time:585779ms step_avg:127.34ms +step:4650/20000 train_loss:1.9470 train_time:592062ms step_avg:127.33ms +step:4700/20000 train_loss:2.1416 train_time:598350ms step_avg:127.31ms +step:4713/20000 val_loss:2.0284 val_bpb:1.2013 train_time:600041ms step_avg:127.32ms +stopping_early: wallclock_cap train_time:600041ms step:4713/20000 +peak memory allocated: 22355 MiB reserved: 22816 MiB +Code size: 53888 bytes +Ternary artifact: 15057568 bytes (lzma) = 15.06MB + lzma: 15057568 bytes = 15.06MB + zlib: 17573053 bytes = 17.57MB + zstd: 17004297 bytes = 17.00MB + code: 53888 bytes +Total submission size: 15111456 bytes = 15.11MB +final_ternary_roundtrip val_loss:2.0310 val_bpb:1.2029 eval_time:64360ms +final_ternary_roundtrip_exact val_loss:2.03099029 val_bpb:1.20286685 \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-03-19_BitNet158/train_gpt.py b/records/track_non_record_16mb/2026-03-19_BitNet158/train_gpt.py new file mode 100644 index 0000000000..a01e3bbb7d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_BitNet158/train_gpt.py @@ -0,0 +1,1253 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +import lzma +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. 
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 5)) + model_dim = int(os.environ.get("MODEL_DIM", 640)) + num_heads = int(os.environ.get("NUM_HEADS", 10)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
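+    # Each step applies X <- a*X + (b*A + c*A@A) @ X with A = X @ X^T: a quintic polynomial
+    # iteration that pushes the singular values of X toward 1 (approximate orthogonalization, no SVD).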
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
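+# Concretely, eval_val below computes val_bpb = (mean token NLL / ln 2) * (token_count / byte_count),
+# where byte_count is the number of UTF-8 bytes the validation tokens decode to.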
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * 
batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# Ternary quantization for BitNet artifact +BITNET_GROUP_SIZE = 64 + +def pack_ternary(q: Tensor) -> tuple[bytes, list]: + """Pack ternary {-1,0,1} as base-3: 5 trits per byte (1.6 bits/trit). Lossless.""" + flat = (q.reshape(-1).to(torch.int8) + 1).numpy() # map {-1,0,1} -> {0,1,2} + n = len(flat) + pad = (5 - n % 5) % 5 + if pad: + flat = np.concatenate([flat, np.zeros(pad, dtype=np.int8)]) + groups = flat.reshape(-1, 5) + # Encode 5 trits as: t0 + 3*t1 + 9*t2 + 27*t3 + 81*t4 + packed = (groups[:, 0].astype(np.uint8) + + groups[:, 1].astype(np.uint8) * 3 + + groups[:, 2].astype(np.uint8) * 9 + + groups[:, 3].astype(np.uint8) * 27 + + groups[:, 4].astype(np.uint8) * 81) + return packed.tobytes(), [n] + + +def unpack_ternary(data: bytes, n: int) -> Tensor: + """Unpack base-3 encoded ternary back to {-1,0,1}.""" + packed = np.frombuffer(data, dtype=np.uint8) + trits = np.zeros((len(packed), 5), dtype=np.int8) + vals = packed.astype(np.int16) + for i in range(5): + trits[:, i] = vals % 3 + vals //= 3 + flat = trits.reshape(-1)[:n] + return torch.from_numpy(flat.astype(np.int8) - 1) # map {0,1,2} -> {-1,0,1} + + +def quantize_state_dict_ternary(state_dict: dict[str, Tensor], model: nn.Module = None, group_size: int = BITNET_GROUP_SIZE): + """Ternary for large matrices (packed base-3), fp16 for embedding + scalars. + If model is provided, uses cached quantization from the last forward pass.""" + # Build cache from model's BitLinear layers + cache = {} + if model is not None: + for name, mod in model.named_modules(): + if isinstance(mod, BitLinear) and hasattr(mod, '_cached_q'): + cache[name + '.weight'] = (mod._cached_q, mod._cached_scale, mod._cached_shape) + + quantized: dict[str, object] = {} + stats = {"ternary_bytes": 0, "fp16_bytes": 0} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().float().contiguous() + if name in cache: + q, scale, shape = cache[name] + q = q.cpu() + scale = scale.cpu() + packed_bytes, pack_meta = pack_ternary(q) + quantized[name] = {"type": "ternary", "packed": packed_bytes, + "scale": scale, "shape": list(shape), + "padded_cols": shape[1], "group_size": group_size, + "n_trits": pack_meta[0]} + stats["ternary_bytes"] += len(packed_bytes) + scale.numel() * 2 + elif t.ndim == 2 and t.numel() > 65_536 and "tok_emb" not in name: + pad = (group_size - t.shape[1] % group_size) % group_size + t_padded = F.pad(t, (0, pad)) if pad > 0 else t + t_grouped = t_padded.reshape(-1, group_size) + scale = t_grouped.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() + q = (t_grouped / scale).round().clamp(-1, 1).to(torch.int8) + packed_bytes, pack_meta = pack_ternary(q) + quantized[name] = {"type": "ternary", "packed": packed_bytes, + "scale": scale.to(torch.float16).squeeze(-1), + "shape": list(t.shape), "padded_cols": t_padded.shape[1], + "group_size": group_size, "n_trits": pack_meta[0]} + stats["ternary_bytes"] += len(packed_bytes) + scale.numel() * 2 + else: + quantized[name] = {"type": "fp16", "data": t.to(torch.float16)} + stats["fp16_bytes"] += t.numel() * 2 + return quantized, stats + + +def dequantize_state_dict_ternary(quantized: dict[str, object], target_dtype=torch.bfloat16) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, entry in quantized.items(): + if entry["type"] == "ternary": + q = 
unpack_ternary(entry["packed"], entry["n_trits"]) + q = q.float().reshape(-1, entry["group_size"]) + scale = entry["scale"].float().unsqueeze(-1) + t = (q * scale).reshape(-1, entry["padded_cols"]) + shape = entry["shape"] + out[name] = t[:shape[0], :shape[1]].to(target_dtype).contiguous() + else: + out[name] = entry["data"].to(target_dtype).contiguous() + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class BitLinear(nn.Linear): + """BitNet b1.58: ternary weights {-1, 0, 1} with STE, per-group absmax scaling.""" + def __init__(self, in_features: int, out_features: int, bias: bool = False, group_size: int = 64): + super().__init__(in_features, out_features, bias=bias) + self.group_size = group_size + self._skip_quantize = False + + def _quantize_weights(self, w: Tensor) -> Tensor: + # Per-group absmax ternary quantization with STE + shape = w.shape + g = self.group_size + w_flat = w.reshape(-1, g) + scale = w_flat.abs().mean(-1, keepdim=True).clamp(min=1e-8).half().float() + q = (w_flat / scale).round().clamp(-1, 1) + # Cache for serialization + self._cached_q = q.detach().to(torch.int8) + self._cached_scale = scale.detach().squeeze(-1).half() + self._cached_shape = shape + result = q * scale + return (result.reshape(shape) - w).detach() + w # STE + + def forward(self, x: Tensor) -> Tensor: + # RMSNorm on input activations (BitNet b1.58 style) + x = F.rms_norm(x, (x.size(-1),)) + if self._skip_quantize: + w = self.weight + else: + w = self._quantize_weights(self.weight) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = BitLinear(dim, dim, bias=False) + self.c_k = BitLinear(dim, kv_dim, bias=False) + self.c_v = BitLinear(dim, kv_dim, bias=False) + self.proj = BitLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = BitLinear(dim, hidden, bias=False) + self.proj = BitLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = 
RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
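+        # While skips remain, decoder layer i adds skip_weights[i] * (output of encoder layer num_encoder_layers - 1 - i).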
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + 
dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} 
" + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + lr_warmup_steps = int(os.environ.get("LR_WARMUP_STEPS", "50")) + + def lr_mul(step: int, elapsed_ms: float) -> float: + # Linear warmup + if step < lr_warmup_steps: + return step / max(lr_warmup_steps, 1) + # Wallclock-aware linear warmdown (same as baseline) + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is not None and max_wallclock_ms > 0: + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, 
+ rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
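+ # Per-rank timers can drift, and if ranks disagreed about stopping, some would leave the loop while others kept issuing + # collectives and the run would hang, so the decision below is all_reduce'd with MAX so that every rank stops at the same step. + # Recording stop_after_step instead of breaking here defers the exit to the top of the next iteration, which lets one final + # validation pass run before the loop ends.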
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Log the submission code size, quantize the weights to ternary, keep the smallest compressed encoding, + # and validate the packed artifact with a roundtrip evaluation. + if master_process: + code_bytes = len(code.encode("utf-8")) + log0(f"Code size: {code_bytes} bytes") + + # Ternary roundtrip + # Run one forward pass to populate caches + base_model.eval() + with torch.no_grad(): + x_dummy, y_dummy = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + base_model(x_dummy, y_dummy) + + tern_obj, tern_stats = quantize_state_dict_ternary(base_model.state_dict(), model=base_model) + tern_buf = io.BytesIO() + torch.save(tern_obj, tern_buf) + tern_raw = tern_buf.getvalue() + tern_lzma = lzma.compress(tern_raw, preset=9) + tern_zlib = zlib.compress(tern_raw, 9) + candidates = [("lzma", tern_lzma), ("zlib", tern_zlib)] + if HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + candidates.append(("zstd", cctx.compress(tern_raw))) + compress_method, tern_blob = min(candidates, key=lambda x: len(x[1])) + if master_process: + with open("final_model.ternary.ptz", "wb") as f: + f.write(tern_blob) + tern_file_bytes = os.path.getsize("final_model.ternary.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Ternary artifact: {tern_file_bytes} bytes ({compress_method}) = {tern_file_bytes/1e6:.2f}MB") + for method, blob in candidates: + log0(f" {method}: {len(blob)} bytes = {len(blob)/1e6:.2f}MB") + log0(f" code: {code_bytes} bytes") + log0(f"Total submission size: {tern_file_bytes + code_bytes} bytes = {(tern_file_bytes + code_bytes)/1e6:.2f}MB") + + base_model.load_state_dict(dequantize_state_dict_ternary(tern_obj), strict=True) + # Disable re-quantization for roundtrip eval — loaded weights are already dequantized ternary + for mod in base_model.modules(): + if isinstance(mod, BitLinear): + mod._skip_quantize = True + torch.cuda.synchronize() + t_terneval = time.perf_counter() + tern_val_loss, tern_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0(f"final_ternary_roundtrip val_loss:{tern_val_loss:.4f} val_bpb:{tern_val_bpb:.4f} eval_time:{1000.0 * (time.perf_counter() - t_terneval):.0f}ms") + log0(f"final_ternary_roundtrip_exact val_loss:{tern_val_loss:.8f} val_bpb:{tern_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()