1.1145 BPB: Parallel Muon + INT5 GPTQ + Legal TTT #1171
EthanYangTW wants to merge 1 commit into openai:main
Conversation
Key innovations over previous submission (1.1195, PR openai#529):

1. **Parallel Muon Optimizer** — Parameter banking with async reduce-scatter/all-gather overlapping Newton-Schulz orthogonalization. Three-phase training loop: (1) launch async reduce-scatter for the banks, (2) all-reduce + Adam step for replicated params (overlaps with the reduce-scatter), (3) wait on the reduce-scatter, run NS5, launch async all-gather. Eliminates the DDP wrapper entirely. From PR openai#1120 (Rascal/Cambrian).
2. **INT5 Quantization (clip_range=15)** — 31 unique integer levels instead of 63 (INT6). Combined with GPTQ Hessian-aware error compensation, achieves ~0.476 bytes/param compression vs ~0.64 for INT6. Enables fitting a larger model (MHA 8/8, MLP 3.5x, BigramHash 6144, ~32M unique params) under the 16MB artifact limit.
3. **Coprime Stride Data Loader** — Deterministic, permutation-free sampling using coprime strides over memory-mapped shards. Each shard is traversed via a stride coprime to its block count, guaranteeing full coverage without storing permutation arrays. Adaptive shard selection with power-law weighting (alpha decays 0.9→0.5 over training).
4. **Wallclock-Adaptive LR Schedule** — LR warmdown triggers based on elapsed wallclock time rather than step count. Automatically adapts to varying step times across hardware, ensuring consistent convergence regardless of system performance.
5. **MHA 8/8 + MLP 3.5x + BigramHash 6144** — Larger architecture than previous submissions (was GQA 8/4, MLP 3.0x, BigramHash 2048). Full multi-head attention, wider MLP, richer bigram hash embeddings. Only possible due to INT5 compression.

Architecture: 11L, dim=512, MHA 8/8, MLP 3.5x (1792), LeakyReLU²(0.5), XSA on all 11 layers, partial RoPE 16/64, LN scale 1/√(L+1), SmearGate, OrthoInit, BigramHash 6144, shared VE128 (layers 9, 10), U-Net skip connections, EMA 0.997, tight SWA (every 50 steps), late QAT (threshold 0.15), Muon lr=0.025 WD=0.04 (momentum warmup 0.92→0.99 over 1500 steps)

Training: 94ms/step → ~6333 steps in 600s wallclock on 8×H100 SXM

Quantization: INT5 GPTQ (clip_range=15, block_size=64, 256-sample calibration) + 2% magnitude pruning + zstd-22 compression

Eval: Sliding window (stride=64) + legal score-first AdamW TTT (5 epochs, lr=0.0001, last 2 blocks + norms + head unfrozen, 262144-token chunks)

3-seed results:

- Seed 1337: 1.1144 BPB (16.12 MB artifact)
- Seed 42: 1.1141 BPB (15.12 MB artifact)
- Seed 7: 1.1150 BPB (15.26 MB artifact)
- Mean: 1.1145 BPB (std 0.0005)
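For context, a minimal sketch of the stride-64 sliding-window scoring described under Eval above; the `model` call signature, window length, and helper name are illustrative assumptions, not the submission's code:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_nll(model, tokens, window=1024, stride=64, device="cuda"):
    """Score every token once with a sliding context window (sketch).

    Each forward pass sees up to `window` tokens of context, but only the
    trailing `stride` positions contribute to the loss, so every token is
    scored with close to the maximum available left context and no token
    is counted twice. Assumes `model(x)` returns logits of shape (B, T, V).
    """
    total_nll, total_tokens = 0.0, 0
    for end in range(stride, tokens.numel(), stride):
        start = max(0, end - window)
        chunk = tokens[start:end + 1].unsqueeze(0).to(device)  # +1 for targets
        logits = model(chunk[:, :-1])
        targets = chunk[:, 1:]
        nll = F.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
        )
        fresh = nll[-stride:]  # positions already covered by earlier windows are skipped
        total_nll += fresh.sum().item()
        total_tokens += fresh.numel()
    # mean nats/token; BPB additionally divides by ln(2) and the bytes-per-token ratio
    return total_nll / total_tokens
```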
Community Review — 1.1145 BPB: Parallel Muon + INT5 GPTQ + Legal TTT

Compliance: LOOKS CLEAN — legal score-first-per-chunk TTT (PR #1413 pattern)

### Analysis

Head SHA: 38702da
Files changed: train_gpt.py only

### N-gram / BigramHash family bug check — CLEAN
Summary
3-seed mean: 1.1145 BPB (std 0.0005)
All runs: 600s training + ~335s eval (sliding window stride=64 + 5-epoch TTT) on 8×H100 SXM.
Key Techniques
1. INT5 GPTQ Quantization (clip_range=15)
31 unique integer levels instead of the standard 63 (INT6). Combined with full GPTQ (Hessian-aware error compensation, column reordering, 256-sample self-generated calibration), achieves ~0.476 bytes/param — 26% smaller than INT6. This unlocks fitting a larger model under the 16MB artifact limit.
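A minimal sketch of the per-block INT5 rounding with symmetric per-row scales; GPTQ's Hessian-aware error compensation, column reordering, and calibration are omitted here, and the helper names are illustrative:

```python
import torch

def quantize_int5_blocks(weight, clip_range=15, block_size=64):
    """Symmetric per-block round-to-nearest quantization to INT5 (sketch).

    clip_range=15 gives integer levels in [-15, 15]: 31 unique values,
    vs. 63 for the INT6 setting (clip_range=31). One scale is stored per
    row per block of `block_size` input columns.
    """
    out_f, in_f = weight.shape
    q = torch.empty_like(weight, dtype=torch.int8)
    scales = torch.empty(out_f, (in_f + block_size - 1) // block_size)
    for b, start in enumerate(range(0, in_f, block_size)):
        block = weight[:, start:start + block_size]
        # largest magnitude in each row of the block maps to clip_range
        scale = block.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / clip_range
        q_block = torch.clamp(torch.round(block / scale), -clip_range, clip_range)
        q[:, start:start + block_size] = q_block.to(torch.int8)
        scales[:, b] = scale.squeeze(1)
    return q, scales

def dequantize_int5_blocks(q, scales, block_size=64):
    out = torch.empty(q.shape, dtype=torch.float32)
    for b, start in enumerate(range(0, q.shape[1], block_size)):
        out[:, start:start + block_size] = (
            q[:, start:start + block_size].float() * scales[:, b:b + 1]
        )
    return out
```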
2. XSA on All 11 Layers
Cross-sequence attention applied to every layer, not just the last 4. Against conventional wisdom, but consistently better in our ablations.
3. Legal Score-First Chunked TTT
Validation data split into 262144-token chunks. For each chunk: score first (sliding window, inference mode), then adapt with AdamW (lr=0.0001, 5 epochs, last 2 blocks + norms + head unfrozen). Cosine LR decay across chunks. Every token scored BEFORE any gradient update touches it.
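A sketch of the score-first-per-chunk loop, under assumptions noted in the comments (a GPT-2-style `model(x, targets=y)` forward returning the loss, a `score_fn` wrapping the sliding-window scorer, and trainable params already restricted by the caller):

```python
import math
import torch

def score_first_ttt(model, chunks, score_fn, lr=1e-4, epochs=5, train_window=2048):
    """Legal test-time training: every token is scored BEFORE any update sees it.

    `chunks` is a list of 1-D token tensors (262144 tokens each in the PR);
    `score_fn(model, chunk)` returns (total_nll, n_tokens). Only the last two
    blocks, the norms, and the head are assumed to have requires_grad=True.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=lr)
    total_nll, total_tokens = 0.0, 0
    for i, chunk in enumerate(chunks):
        # 1) score this chunk with the model adapted on *previous* chunks only
        model.eval()
        with torch.inference_mode():
            nll, n = score_fn(model, chunk)
        total_nll, total_tokens = total_nll + nll, total_tokens + n
        # 2) only now adapt on the already-scored chunk (cosine LR across chunks)
        model.train()
        chunk_lr = lr * 0.5 * (1 + math.cos(math.pi * i / max(1, len(chunks) - 1)))
        for g in opt.param_groups:
            g["lr"] = chunk_lr
        for _ in range(epochs):
            for s in range(0, chunk.numel() - 1, train_window):
                x = chunk[s:s + train_window].unsqueeze(0)
                y = chunk[s + 1:s + train_window + 1].unsqueeze(0)
                x = x[:, :y.size(1)]
                loss = model(x, targets=y)  # assumed GPT-2-style forward
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
    return total_nll / total_tokens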
4. Coprime Stride Data Loader
Deterministic permutation-free sampling using strides coprime to shard block counts. Guarantees full data coverage without storing permutation arrays. Adaptive shard selection with decaying power-law weighting.
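A sketch of the stride idea, assuming uint16 token shards; the adaptive power-law shard weighting is omitted and all names are illustrative:

```python
import math
import random
import numpy as np

class CoprimeStrideLoader:
    """Deterministic, permutation-free block sampling over a memory-mapped shard.

    Blocks 0..n_blocks-1 are visited in the order (offset + i*stride) % n_blocks
    with gcd(stride, n_blocks) == 1, so every block is covered exactly once per
    pass without materializing a permutation array.
    """
    def __init__(self, shard_path, block_tokens, seed=0):
        self.data = np.memmap(shard_path, dtype=np.uint16, mode="r")  # assumes uint16 tokens
        self.block_tokens = block_tokens
        self.n_blocks = len(self.data) // block_tokens
        rng = random.Random(seed)
        stride = 1
        if self.n_blocks > 2:
            stride = rng.randrange(1, self.n_blocks)
            while math.gcd(stride, self.n_blocks) != 1:
                stride = rng.randrange(1, self.n_blocks)
        self.stride = stride
        self.offset = rng.randrange(max(1, self.n_blocks))
        self.i = 0

    def next_block(self):
        idx = (self.offset + self.i * self.stride) % self.n_blocks
        self.i += 1
        start = idx * self.block_tokens
        return np.asarray(self.data[start:start + self.block_tokens], dtype=np.int64)
```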
5. Wallclock-Adaptive LR Schedule
LR warmdown triggers based on elapsed wall time rather than step count, automatically adapting to hardware variation.
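A sketch with illustrative constants (the actual warmdown fraction and floor may differ from the submission's):

```python
import time

def wallclock_lr(base_lr, t_start, total_seconds=600.0, warmdown_frac=0.4, min_frac=0.0):
    """LR multiplier driven by elapsed wall time rather than step count (sketch).

    LR stays at base_lr until the final `warmdown_frac` of the time budget,
    then decays linearly to min_frac * base_lr at the deadline, so fast and
    slow hardware both finish the warmdown exactly at the wallclock limit.
    """
    frac = min(1.0, (time.time() - t_start) / total_seconds)
    warmdown_start = 1.0 - warmdown_frac
    if frac < warmdown_start:
        return base_lr
    progress = (frac - warmdown_start) / warmdown_frac
    return base_lr * ((1 - progress) + progress * min_frac)
```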
6. Parallel Muon Optimizer
Parameter banking with async reduce-scatter/all-gather overlapping Newton-Schulz orthogonalization (adapted from PR #1120). Three-phase training loop eliminates DDP wrapper.
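A structural sketch of the three-phase step, assuming pre-built banks of flat gradient/parameter buffers with per-rank shard views; momentum, weight decay, and the banking itself are omitted and the bank dict layout is illustrative:

```python
import torch
import torch.distributed as dist

def zeropower_ns5(G, steps=5):
    """Newton-Schulz iteration approximately orthogonalizing G (Muon core)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16() / (G.norm() + 1e-7)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return (X.T if transposed else X).to(G.dtype)

def three_phase_step(banks, replicated, adam, world_size):
    """One optimizer step overlapping communication with compute (sketch)."""
    # Phase 1: launch async reduce-scatter of every bank's flat gradient buffer
    rs_handles = [
        dist.reduce_scatter_tensor(b["grad_shard"], b["grad_flat"], async_op=True)
        for b in banks
    ]
    # Phase 2: replicated (non-Muon) params take their Adam step while RS runs
    for p in replicated:
        dist.all_reduce(p.grad)
        p.grad /= world_size
    adam.step()
    # Phase 3: wait on each bank, orthogonalize its local shard, all-gather back
    ag_handles = []
    for b, h in zip(banks, rs_handles):
        h.wait()
        for g, p, lr in b["shard_views"]:  # 2-D (grad, param, lr) views of the shard
            p.add_(zeropower_ns5(g), alpha=-lr)
        ag_handles.append(
            dist.all_gather_into_tensor(b["param_flat"], b["param_shard"], async_op=True)
        )
    for h in ag_handles:
        h.wait()
```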
Architecture
Training: Muon (lr=0.025, WD=0.04, NS5) + AdamW. 94ms/step, ~6333 steps in 600s.
Compliance