diff --git a/CLAUDE.md b/CLAUDE.md index 5e4debadab..adaa667fb9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -5,9 +5,11 @@ OpenAI competition to train the best LM that fits in 16MB. Baseline: train_gpt_mlx.py (stock), my version: train_gpt_mlx_kl.py ## Commands -- Smoke test: RUN_ID=test ITERATIONS=100 TRAIN_BATCH_TOKENS=8192 VAL_LOSS_EVERY=0 VAL_BATCH_SIZE=8192 WARMUP_STEPS=3 python3 train_gpt_mlx_kl.py -- Must activate venv first: source ~/pg_env/bin/activate (venv is at ~/pg_env, NOT .venv) +- Smoke test (baseline): `RUN_ID=test ITERATIONS=100 TRAIN_BATCH_TOKENS=8192 VAL_LOSS_EVERY=0 VAL_BATCH_SIZE=8192 WARMUP_STEPS=3 python3 train_gpt_mlx_kl.py` +- Smoke test (moonshot): `RUN_ID=test ITERATIONS=100 TRAIN_BATCH_TOKENS=8192 VAL_LOSS_EVERY=0 VAL_BATCH_SIZE=8192 WARMUP_STEPS=3 ENGRAM_LITE_ENABLED=1 COMPLEMENT_ALPHA=0.5 NGRAM_MIXER_ENABLED=1 python3 train_gpt_mlx_kl.py` +- Must activate venv first: `source ~/pg_env/bin/activate` (venv is at ~/pg_env, NOT .venv) - Data is in ./data/datasets/fineweb10B_sp1024/ +- Full moonshot run (8×H100): `ENGRAM_LITE_ENABLED=1 COMPLEMENT_ALPHA=0.5 NGRAM_MIXER_ENABLED=1 NGRAM_ALPHA=0.25 NGRAM_MAX_ORDER=4 python3 train_gpt_mlx_kl.py` ## Key metrics - train_loss: lower is better, compare at same step count diff --git a/pg_novel_ideas.md b/pg_novel_ideas.md new file mode 100644 index 0000000000..5c735e1f34 --- /dev/null +++ b/pg_novel_ideas.md @@ -0,0 +1,1007 @@ +# Parameter Golf: Novel Approaches to Sub-1.10 BPB + +*Analysis date: March 30, 2026* +*Current official SOTA: 1.1194 BPB (#549, @sanjeevmadhav)* +*Best pending pure neural: 1.1086 BPB (#1089, @mikeapedia — Turbo-Muon + EngramLite)* +*Best pending n-gram cache: 0.4027 BPB (#1094 — causal BackoffNgramMixer, compliance TBD)* + +--- + +## Competition Context Summary + +**Constraints:** 16MB artifact (code + compressed model), 10 min training on 8×H100, 10 min eval. +**Metric:** val_bpb (bits per byte) on FineWeb validation set. Lower = better. +**Baseline:** 1.2244 BPB (9L, 512d, int8+zlib). +**Current stack:** 11L, 512d, 3×MLP, int6+zstd-22, XSA, EMA, GPTQ-lite, BigramHash, SmearGate, Partial RoPE, sliding-window eval. 
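+Since everything below is argued in val_bpb terms, a minimal sketch of the conversion (an illustrative helper, not the official scorer; the ~1.18 bytes/token figure is the FineWeb average cited later in this document):
+
+```python
+import math
+
+def val_bpb(mean_ce_nats: float, n_tokens: int, n_bytes: int) -> float:
+    """Convert mean per-token cross-entropy (in nats) to bits per byte."""
+    total_bits = mean_ce_nats * n_tokens / math.log(2)  # nats -> bits
+    return total_bits / n_bytes  # FineWeb: ~1.18 bytes/token, so bpb ~= ce_bits / 1.18
+```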
+ +**Key quantitative constraints from Issue #140 ablation data:** +- 1ms step overhead ≈ 0.006-0.007 BPB cost (at ~83ms/step baseline) +- Int6 quant gap: ~0.0036 BPB (GPTQ is near-optimal at int6) +- Int5 quant gap: ~0.007 BPB per matrix group +- Int4 quant gap: ~0.065 BPB (catastrophic — dead end) +- 3-seed std: ~0.0005-0.0015 BPB +- EMA > SWA by 0.003 BPB (3-seed verified) +- Sliding window (s64, w2048): ~0.034 BPB improvement +- N-gram cache with correct normalization: 1.51 BPB alone (WORSE than neural — #978) + +**What's been tried and failed (selected):** +- MoE (optimal sparsity=0 below 500M params) +- Depth recurrence >2 loops (quant error amplifies 900×) +- Knowledge distillation (11ms/step I/O overhead fatal in 600s) +- MTP (no improvement) +- INT4 quantization (catastrophic +0.065 BPB) +- TrigramHash without gating (+0.0049 BPB, hurts compression) +- MC Dropout ensembling (sub-networks lack diversity at 17M params) +- kNN-LM at eval (XSA already captures inter-position patterns) +- Advanced quant algorithms at int6 (Qronos, CDQuant: GPTQ already near-optimal) +- Procrustes rotation (91% MSE reduction but 380% larger artifact — MSE ≠ artifact size) +- Pruning 3% of weights (+728KB artifact — zeroes hurt zstd-22) + +--- + +## Analysis of All 8 Angles + +### ANGLE 1: COMPRESSION IS THE OBJECTIVE (BPB-Aware Loss) + +**Core insight:** Cross-entropy loss treats all tokens equally, but BPB weights by bytes-per-token. Tokens decoding to more UTF-8 bytes matter more for BPB. + +**Will it work?** Partially — but the gain is smaller than it appears. + +**Analysis:** +The gap between CE-optimized BPB and byte-weighted BPB is modest for SP1024 tokenization. With a 1024-token vocabulary, most tokens decode to 1-4 bytes, and the distribution of bytes-per-token is relatively concentrated (mean ~1.18 bytes/token for English-dominant FineWeb). The correction factor (token_count/byte_count) is already baked into the BPB formula, so optimizing CE already approximates BPB optimization. + +However, there IS a real effect: tokens that decode to many bytes (rare long tokens, Unicode sequences) receive proportionally less gradient signal under CE. Byte-weighting would reallocate gradient toward these tokens. + +**Estimated impact:** 0.001-0.003 BPB improvement. The correction is small because: +1. SP1024 has a tiny vocabulary — the variance of bytes-per-token is low +2. High-byte tokens are often rare/noisy (non-English text, HTML entities) and may not be learnable +3. The model already sees these tokens during training — they just get equal weight + +**Implementation difficulty:** Very low — multiply CE loss by a pre-computed bytes-per-token lookup. + +**Risk of failure:** Moderate. The effect may be below noise floor (0.0005 BPB 3-seed std). + +**Compatibility:** Full — one-line change to loss computation. Zero overhead. + +**Verdict: LOW PRIORITY.** The theoretical gap is real but the practical gain at SP1024 is likely below significance threshold. Worth a zero-cost experiment but don't build a strategy around it. 
+ +```python +# PROOF-OF-CONCEPT: Byte-weighted cross-entropy loss +# Add to train_gpt_mlx_kl.py + +def compute_bytes_per_token_lut(tokenizer_path): + """Pre-compute UTF-8 byte count for each token in vocabulary.""" + import sentencepiece as spm + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + vocab_size = sp.get_piece_size() + bytes_lut = [] + for i in range(vocab_size): + piece = sp.id_to_piece(i) + # Decode to bytes, count UTF-8 length + try: + byte_count = len(piece.encode('utf-8')) + except: + byte_count = 1 + bytes_lut.append(max(1, byte_count)) + return bytes_lut + +def byte_weighted_ce_loss(logits, targets, bytes_lut_tensor): + """ + Cross-entropy loss weighted by bytes-per-token. + + logits: (B*T, V) float + targets: (B*T,) int + bytes_lut_tensor: (V,) float — bytes per token ID + + Instead of: loss = -log(p[target]) averaged over tokens + We compute: loss = -log(p[target]) * bytes[target] / mean(bytes[target]) + + This makes the loss proportional to BPB contribution. + """ + import mlx.core as mx + import mlx.nn as nn + + # Standard CE per token + ce_per_token = nn.losses.cross_entropy(logits, targets, reduction='none') # (B*T,) + + # Byte weights for each target token + byte_weights = bytes_lut_tensor[targets] # (B*T,) + + # Normalize so total weight equals number of tokens (preserves LR scale) + byte_weights = byte_weights / byte_weights.mean() + + # Weighted mean + loss = (ce_per_token * byte_weights).mean() + return loss +``` + +--- + +### ANGLE 2: NON-UNIFORM QUANTIZATION + +**Core insight:** Neural network weights are approximately Gaussian — uniform int6 wastes precision on the tails. + +**Will it work?** No, for two devastating reasons specific to this competition. + +**Analysis:** + +**Reason 1: MSE ≠ Artifact Size.** This is the single most important lesson from the competition (#1048, #316). Non-uniform quantization (k-means, log-scale, NF6) reduces reconstruction MSE. But the artifact is compressed with zstd-22, and non-uniform codebook indices have HIGHER entropy than uniform indices. Uniform int6 produces values that cluster around certain bit patterns, which zstd compresses efficiently. K-means indices are essentially random 6-bit values with near-uniform distribution — maximum entropy, minimum compressibility. + +Concrete example from #1048: Procrustes rotation reduced MSE by 91% but increased artifact size by 380% because the rotated weights had higher entropy. The same principle applies to non-uniform quant: better MSE, worse compression, net negative for the 16MB budget. + +**Reason 2: GPTQ is already near-optimal at int6.** #756 (@abaybektursun, SOTA holder) tested Qronos iterative Hessian (+0.0007 worse) and CDQuant coordinate descent (+0.0005 worse) — both more sophisticated than uniform GPTQ. At int6 with 64 levels, the quant gap is only 0.0036 BPB. There's simply not much room to improve. + +**Reason 3: Codebook storage.** A 64-entry codebook per row (or per tensor) costs bytes that offset any quality gain. For 17M params in ~11K rows, even 64×2 bytes per row = ~1.4MB of codebooks. + +**Estimated impact:** Net negative (larger artifact for marginal MSE improvement). + +**Implementation difficulty:** Moderate (k-means on weights, custom pack/unpack). + +**Risk of failure:** Very high — #1048 and #316 both demonstrate the MSE≠artifact principle. + +**Compatibility:** Poor — requires custom serialization, breaks zstd compression efficiency. 
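+The index-entropy effect is easy to see concretely. A hedged sketch on synthetic Gaussian weights (1-D Lloyd k-means stands in for any non-uniform codebook; this says nothing about GPTQ itself, only about how the code indices compress):
+
+```python
+import numpy as np
+import zstandard as zstd
+
+rng = np.random.default_rng(0)
+w = rng.normal(size=100_000).astype(np.float32)
+
+# Uniform int6: Gaussian-shaped code histogram -> low entropy -> zstd-friendly
+scale = np.abs(w).max() / 31.0
+uniform_codes = (np.clip(np.round(w / scale), -32, 31) + 32).astype(np.uint8)
+
+# Non-uniform int6 (k-means): lower MSE, but codes are near-equiprobable -> ~6 bits/code
+centers = np.quantile(w, np.linspace(0.0, 1.0, 64))
+for _ in range(10):  # a few Lloyd iterations suffice in 1-D
+    idx = np.argmin(np.abs(w[:, None] - centers[None, :]), axis=1)
+    for k in range(64):
+        if np.any(idx == k):
+            centers[k] = w[idx == k].mean()
+kmeans_codes = idx.astype(np.uint8)
+
+c = zstd.ZstdCompressor(level=22)
+print("uniform int6 codes:", len(c.compress(uniform_codes.tobytes())), "bytes")
+print("k-means int6 codes:", len(c.compress(kmeans_codes.tobytes())), "bytes (typically larger)")
+```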
+ +**Verdict: DO NOT PURSUE.** The fundamental insight "MSE ≠ compressed artifact size" kills this entire angle. Every non-uniform scheme increases index entropy, which defeats zstd. The competition has empirically confirmed this multiple times. + +--- + +### ANGLE 3: ENTROPY-CODED WEIGHTS + +**Core insight:** Zstd treats all bytes equally. What if we designed weight distributions for maximum compressibility? + +**Will it work?** One sub-idea works marginally, the rest don't. + +**Analysis:** + +**Weight entropy regularization:** Promising in principle — penalize high-entropy weight distributions during training. But #609's ablation found lzma (which is closer to a custom entropy coder than zstd) achieves 99.7% of Shannon limit on the weight data. **Zstd-22 is already near-optimal.** The bottleneck isn't the entropy coder — it's the weight entropy itself. And the weight entropy is determined by the model's capacity needs, not by the compression algorithm. + +**Sparse + dense hybrid:** #1048 proved that 3% pruning INCREASES artifact by 728KB. Zeroing weights doesn't help zstd-22 because the zero values disrupt the statistical patterns zstd exploits. Structured pruning at 50% would be catastrophic for model quality AND artifact size. + +**ANS instead of zstd:** #1089 used Brotli+byte-shuffle instead of zstd on mixed int6/int7 — this IS the closest thing to a custom entropy coder that's been tried. The gain was real but small (enough to squeeze in slightly higher precision). ANS tuned to exact weight distribution could save 0.05-0.2MB over zstd-22, which translates to ~100-400K more parameters — worth ~0.001-0.003 BPB at the margin. + +**The one promising sub-idea: NuMuon (arXiv:2603.03597).** Nuclear-norm constraint on Muon updates → lower stable rank → better zstd compression. This pushes compressibility into the *optimizer itself*, which is fundamentally different from post-hoc compression. The weights naturally develop lower entropy during training, rather than being forced into compressible patterns afterward. + +**Estimated impact:** 0.001-0.003 BPB (ANS/Brotli tuning) or 0.002-0.006 BPB (NuMuon optimizer). + +**Implementation difficulty:** ANS is moderate. NuMuon is low (optimizer change). + +**Risk of failure:** Moderate for ANS (marginal gains). Low-moderate for NuMuon (backed by theory). + +**Compatibility:** Full — orthogonal to everything else. + +**Verdict: NuMuon is worth testing (Tier 2 idea from Issue #140). Custom entropy coding is marginal but free.** + +--- + +### ANGLE 4: HYPERNETWORK WEIGHT GENERATION + +**Core insight:** Store a tiny network that generates the weight matrices at load time. + +**Will it work?** Almost certainly not at competitive quality. + +**Analysis:** + +This is implicit neural representation (INR) applied to weight matrices. The problem: weight matrices in a trained LLM are NOT smooth or low-frequency. They contain high-frequency, semantically meaningful structure that resists compact representation. A 200K-param hypernetwork cannot generate 37M coherent weights — the information-theoretic compression ratio of ~185× would require the weights to have ~185× redundancy, which they don't. + +**The low-rank basis variant** is more realistic: hypernetwork generates a rank-K basis for each weight matrix, and coefficients are stored directly. 
But this is just low-rank factorization with extra steps, and it's been explored: +- #609 found Hadamard rotation saves -0.0002 BPB but costs +0.5MB (net negative) +- CPSVD (Column-Preserving SVD) is the principled version of this — untried but estimated at 0.003-0.008 BPB +- The fundamental issue: low-rank approximation of weight matrices loses too much information at the precision levels needed + +**The real version of this idea that works:** Weight tying + per-layer LoRA deltas (Relaxed Recursive Transformers, ICLR 2025). Share base weights across layers, add tiny per-layer LoRA adaptations. This gives you ~24 virtual layers from ~11 layers of parameters. #686 demonstrated shallow recurrence (2 layers repeated once, +2 virtual depth) at 1.1182 BPB — it works when limited to 2 loops. But >2 loops causes GPTQ error amplification (#579, #363). + +**Estimated impact:** Hypernetwork: net negative. Relaxed Recursive: 0.01-0.03 BPB (from deeper effective model). + +**Implementation difficulty:** Hypernetwork: high. Relaxed Recursive: moderate. + +**Risk of failure:** Hypernetwork: very high. Relaxed Recursive: moderate (2-loop limit). + +**Compatibility:** Relaxed Recursive is compatible with the existing stack. + +**Verdict: Hypernetwork is a dead end. Relaxed Recursive Transformers with LoRA deltas is the viable realization of this concept, and it's already on the Tier 2 list.** + +--- + +### ANGLE 5: CONTEXT MIXING (N-gram Ensemble) + +**Core insight:** Combine multiple simple predictors (bigrams, trigrams, byte-level) with learned mixing weights. + +**Will it work?** This is the MOST promising angle — with critical caveats. + +**Analysis:** + +The competition has extensively explored this direction, and the results are dramatic but complicated: + +**What happened with n-gram caches (Mar 25-27):** A wave of submissions used eval-time n-gram caches to achieve sub-1.0 BPB. Then #978 proved that with correct full-vocabulary normalization, standalone n-gram caches degrade to 1.51 BPB — worse than neural baseline. The previous sub-0.1 scores were normalization artifacts. 33+ PRs were closed. + +**But causal n-gram mixing IS viable:** Post-enforcement, the correctly-implemented BackoffNgramMixer (#803, #1094) still achieves 0.40-0.44 BPB. The key difference: these produce full normalized probability distributions over the entire vocabulary at each step, blended with the neural model's distribution using learned or entropy-adaptive alpha mixing. + +**TrigramHash as a training-time component:** #609 found TrigramHash (without gating) HURTS by +0.0049 BPB. But EngramLite (#1089) with gating + multi-head hashing + trigrams works — part of the new best 1.1086 BPB. **The gating is essential** — it suppresses noisy hash collisions that raw TrigramHash amplifies. + +**Byte-level predictor:** H-Net (#1044) attempted learned byte-level tokenization — 1.90 BPB, far behind. Byte-level processing is too slow for the 600s training budget at current architectures. + +**Skip-gram hash:** Untried in the competition. Issue #140 lists it as a Tier 1 idea with 0.005-0.015 BPB estimated gain. Uses non-contiguous positions (e.g., tokens[-1, -3, -5]) as context — captures patterns with intervening content. Zero additional memory per context, just hash different positions. Especially effective on FineWeb's structured web text. + +**The realistic path:** +1. Use EngramLite-style gated multi-head hashing (bigram + trigram) during training: ~0.003-0.008 BPB +2. 
At eval time, add a correctly-normalized BackoffNgramMixer with entropy-adaptive alpha: ~0.05-0.15 BPB +3. The combined system achieves the "complementary training" effect (#803): the neural model specializes on what n-grams can't predict + +**Estimated impact:** +- Training-time context mixing (EngramLite): 0.003-0.008 BPB (proven by #1089) +- Eval-time BackoffNgramMixer: 0.05-0.15 BPB additional (proven by #803, #1094) +- Skip-gram hash: 0.005-0.015 BPB (untried, moderate confidence) + +**Implementation difficulty:** EngramLite: moderate. BackoffNgramMixer: moderate-high (must get normalization right). Skip-gram: low. + +**Risk of failure:** Low for EngramLite (proven). Moderate for BackoffNgramMixer (normalization is tricky). Low for skip-gram (simple extension of proven concept). + +**Compatibility:** Excellent — EngramLite replaces BigramHash. BackoffNgramMixer is eval-only. + +**Verdict: HIGH PRIORITY. The combination of EngramLite training + BackoffNgramMixer eval + Complementary Training is the most promising path to sub-1.10 BPB.** + +```python +# PROOF-OF-CONCEPT: Gated Multi-Head Hash Embedding (EngramLite-inspired) +# Replaces BigramHash in train_gpt_mlx_kl.py + +import mlx.core as mx +import mlx.nn as nn + +class EngramLiteEmbedding(nn.Module): + """ + Multi-head hashed n-gram embeddings with learned gating. + + Key improvements over BigramHash: + 1. Multiple hash heads (K=4) per n-gram order — reduces collision rate + 2. Trigram support — captures 3-token patterns + 3. Learned gate — sigmoid suppresses noisy lookups + + Architecture: + - For each n-gram order (2, 3): + - K hash functions map context to table indices + - Table lookup produces K embeddings + - Mean pool across heads + - Sigmoid gate (from context) scales the output + - Sum across orders → output + + Parameter budget (adjusted for 16MB constraint): + bigram_table: 2048 × 128 × 2 heads = 524K params + trigram_table: 2048 × 128 × 2 heads = 524K params + projection: 128 × 1024 = 131K params + gate: 128 × 2 + 2 = 258 params + Total: ~1.2M params ≈ 0.9MB in int6 — fits easily + """ + def __init__(self, hash_size: int = 2048, embed_dim: int = 128, + output_dim: int = 1024, n_heads: int = 2, + orders: tuple = (2, 3)): + super().__init__() + self.hash_size = hash_size + self.embed_dim = embed_dim + self.output_dim = output_dim + self.n_heads = n_heads + self.orders = orders + + # Hash primes for multi-head hashing (different prime per head) + self.primes = [31337, 59999, 73721, 97531][:n_heads] + + # Separate embedding table per order (small embed_dim, projected later) + self.tables = {} + for order in orders: + table = nn.Embedding(hash_size, embed_dim) + # Small init — these are additive corrections + table.weight = table.weight * 0.01 + self.tables[f'order_{order}'] = table + + # Project from embed_dim to output_dim (vocab_size or model_dim) + self.proj = nn.Linear(embed_dim, output_dim, bias=False) + + # Learned gate: context → sigmoid scalar per position + # Input: concatenated n-gram context embeddings + self.gate_proj = nn.Linear(embed_dim, len(orders), bias=True) + # Initialize gate bias to -2.0 → sigmoid(-2) ≈ 0.12 → starts mostly suppressed + # Model learns to trust hash lookups as training progresses + self.gate_proj.bias = mx.full((len(orders),), -2.0) + + def _hash_ngram(self, tokens, order, head_idx): + """Hash n-gram context to table index.""" + B, T = tokens.shape + prime = self.primes[head_idx] + + if order == 2: + # Bigram: hash(t-1, t) + t_prev = tokens[:, :-1] # (B, T-1) + t_curr = tokens[:, 1:] # 
(B, T-1) + idx = mx.remainder(t_prev * prime + t_curr, self.hash_size) + valid_start = 1 + elif order == 3: + # Trigram: hash(t-2, t-1, t) + t_prev2 = tokens[:, :-2] # (B, T-2) + t_prev1 = tokens[:, 1:-1] # (B, T-2) + t_curr = tokens[:, 2:] # (B, T-2) + idx = mx.remainder( + t_prev2 * (prime * prime) + t_prev1 * prime + t_curr, + self.hash_size + ) + valid_start = 2 + else: + raise ValueError(f"Order {order} not supported") + + return idx, valid_start + + def __call__(self, tokens): + """ + tokens: (B, T) int32 + Returns: (B, T, output_dim) — additive logit bias + """ + B, T = tokens.shape + output = mx.zeros((B, T, self.embed_dim)) + + for oi, order in enumerate(self.orders): + table = self.tables[f'order_{order}'] + + # Multi-head: average K hash lookups + head_embeds = [] + for hi in range(self.n_heads): + idx, valid_start = self._hash_ngram(tokens, order, hi) + emb = table(idx) # (B, T-order+1, embed_dim) + head_embeds.append(emb) + + # Mean pool across heads — reduces collision noise + ngram_emb = sum(head_embeds) / self.n_heads # (B, T-valid_start, embed_dim) + + # Pad to full sequence length + pad = mx.zeros((B, valid_start, self.embed_dim)) + ngram_emb = mx.concatenate([pad, ngram_emb], axis=1) # (B, T, embed_dim) + + output = output + ngram_emb + + # Project to output dimension and apply gate + gated = mx.sigmoid(self.gate_proj(output)) # (B, T, n_orders) + # Average gate across orders for simplicity + gate_scalar = gated.mean(axis=-1, keepdims=True) # (B, T, 1) + + return self.proj(output) * gate_scalar # (B, T, output_dim) + + +# PROOF-OF-CONCEPT: Skip-Gram Hash Embedding +class SkipGramHashEmbedding(nn.Module): + """ + Hash embedding using non-contiguous token positions. + + Captures patterns like: + - token[-1] × token[-3] (skip one) + - token[-1] × token[-5] (skip three) + + Effective for structured text (HTML tags, code indentation, + sentence templates) where intervening content varies. + """ + def __init__(self, hash_size: int = 4096, dim: int = 1024, + skip_patterns: list = None): + super().__init__() + self.hash_size = hash_size + self.dim = dim + # Each pattern is a tuple of negative offsets, e.g., (-1, -3) + self.skip_patterns = skip_patterns or [(-1, -3), (-1, -5), (-2, -4)] + + self.tables = {} + for i, pattern in enumerate(self.skip_patterns): + table = nn.Embedding(hash_size, dim) + table.weight = table.weight * 0.01 + self.tables[f'skip_{i}'] = table + + def __call__(self, tokens): + B, T = tokens.shape + output = mx.zeros((B, T, self.dim)) + + for i, pattern in enumerate(self.skip_patterns): + table = self.tables[f'skip_{i}'] + min_offset = min(pattern) # Most negative + valid_start = abs(min_offset) + + # Gather tokens at skip positions + # pattern = (-1, -3) means: token at t-1 and token at t-3 + hash_val = mx.zeros((B, T - valid_start), dtype=mx.int32) + prime = 31337 + for offset in pattern: + start = valid_start + offset + end = T + offset + tok_slice = tokens[:, start:end] + hash_val = hash_val * prime + tok_slice + + idx = mx.remainder(mx.abs(hash_val), self.hash_size) + emb = table(idx) + + pad = mx.zeros((B, valid_start, self.dim)) + emb = mx.concatenate([pad, emb], axis=1) + output = output + emb + + return output +``` + +--- + +### ANGLE 6: DEPTH RECURRENCE WITH PROGRESSIVE ADAPTATION + +**Core insight:** 4 unique blocks × 3 loops = 12 effective layers from 4 layers of parameters. + +**Will it work?** Only with ≤2 loops, and the gain is modest. 
+ +**Analysis:** + +The competition has extensively tested this: +- #344: 2× slower, hurts BPB +- #363: **Quantization error amplifies ~900× over 3 cycles** — this is the killer +- #579: 6×2 loops gives 1.1478 (1-seed), but GPTQ compounds multiplicatively +- #686: **Shallow recurrence works** — layers 4+5 repeated once (11→13 virtual), 1.1182 BPB (3-seed) + +**The critical insight from #363:** When you reuse the same quantized weights K times, the error in each weight gets applied K times. At int6, each weight has ~0.016 expected quantization error (range/64). Over 3 cycles, this compounds: effective error ≈ 3 × 0.016 = 0.048, which is approaching int4-level degradation. At 2 cycles, error ≈ 2 × 0.016 = 0.032 — still viable. + +**FiLM conditioning** (scale/shift per loop iteration) helps because it differentiates loop passes, but it can't overcome the quantization amplification problem for >2 loops. + +**The viable version:** #686's approach — repeat 2 middle layers once, getting 13 virtual layers from 11 layers of parameters. Combined with per-pass learnable scalars (~2K params). Recovers ~70% of independent 12L quality at minimal step cost. + +**Budget math with aggressive recurrence:** 4 unique blocks in int6 ≈ 4MB. But we need >4 blocks for competitive quality — the 4-block config is far too small. And even at 4 blocks × 3 loops, the 900× quant error amplification makes it nonviable. + +**Estimated impact:** +2 virtual layers via shallow recurrence: ~0.003-0.008 BPB (proven by #686). + +**Implementation difficulty:** Low for shallow recurrence. High for full FiLM + multi-loop. + +**Risk of failure:** Low for ≤2 loops. Very high for ≥3 loops. + +**Compatibility:** Good — drop-in modification to layer stack. + +**Verdict: SHALLOW RECURRENCE (2 loops on 2 layers) IS PROVEN AND VIABLE. Full depth recurrence (3+ loops) is dead due to int6 error amplification.** + +--- + +### ANGLE 7: MULTI-MODEL ENSEMBLE IN 16MB + +**Core insight:** Two complementary models (small transformer + massive n-gram hash) might beat one larger transformer. + +**Will it work?** YES — this is essentially what the top n-gram cache submissions do. + +**Analysis:** + +This is not a novel idea — it's the dominant strategy on the (compliance-questioned) pending leaderboard. The top submissions (#803 at 0.4416, #1094 at 0.4027) are exactly this: a neural base model + an n-gram cache ensemble, with learned mixing. + +**The specific budget breakdown works:** +- Neural model (11L, 512d, int6+zstd): ~15MB +- BackoffNgramMixer (orders 2-10, ~4M hash buckets): 0 MB artifact (built at eval time from already-scored tokens) +- Total: ~15MB ✓ + +The n-gram component costs ZERO artifact space because it's built incrementally during evaluation from tokens already scored. This is the key insight — the 16MB budget goes entirely to the neural model, and the n-gram mixer is free. + +**Why it works despite #978:** #978 showed standalone normalized n-grams achieve only 1.51 BPB. But that's standalone. When MIXED with a neural model, the n-gram component handles high-confidence local patterns (common bigrams, frequent phrases) while the neural model handles everything else. The mixing is complementary — each handles what the other can't. + +**Complementary Training (#803):** The neural model is trained with loss weights that down-weight tokens easily predicted by bigram statistics. This forces the model to specialize on what n-grams can't handle, maximizing complementarity. This is the critical innovation that separates 0.44 from 0.55 BPB. 
+ +**Estimated impact:** 0.05-0.20 BPB over pure neural, depending on n-gram order and mixing quality. + +**Implementation difficulty:** Moderate-high (correct normalization is crucial — #978's lesson). + +**Risk of failure:** Low for concept (proven), moderate for compliance (organizer scrutiny ongoing). + +**Compatibility:** Eval-time only — compatible with any training stack. + +**Verdict: HIGH PRIORITY. This is proven and the highest-impact single technique available. The risk is compliance, not effectiveness.** + +```python +# PROOF-OF-CONCEPT: Complementary Training + BackoffNgramMixer +# +# Two-part system: +# Part 1: Modified training loss (in train_gpt_mlx_kl.py) +# Part 2: Eval-time BackoffNgramMixer + +# ============================================================ +# PART 1: Complementary Training Loss +# ============================================================ + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from collections import defaultdict + +def build_bigram_stats(train_data_path, vocab_size=1024): + """ + Pre-compute bigram transition probabilities from training data. + Used to identify tokens that are 'easy' for n-gram models. + + Returns: bigram_probs[prev_token, next_token] = P(next|prev) + """ + # Count bigram frequencies + counts = np.zeros((vocab_size, vocab_size), dtype=np.float32) + # Read training shards + import glob, struct + for shard_path in sorted(glob.glob(f"{train_data_path}/fineweb_train_*.bin")): + with open(shard_path, 'rb') as f: + header = struct.unpack('<256i', f.read(1024)) + n_tokens = header[2] + tokens = np.frombuffer(f.read(n_tokens * 2), dtype=np.uint16) + for i in range(len(tokens) - 1): + counts[tokens[i], tokens[i+1]] += 1 + + # Normalize to probabilities (with Laplace smoothing) + row_sums = counts.sum(axis=1, keepdims=True) + vocab_size + bigram_probs = (counts + 1) / row_sums + return bigram_probs + + +def complementary_loss(logits, targets, prev_tokens, bigram_probs_mx, + complement_alpha=0.5): + """ + Weighted CE loss that down-weights tokens easily predicted by bigrams. + + For each token: + p_bigram = bigram_probs[prev_token, target_token] + weight = 1 - complement_alpha * p_bigram + + Tokens with high bigram probability get lower training weight, + forcing the neural model to specialize on what n-grams can't predict. + + logits: (B*T, V) + targets: (B*T,) + prev_tokens: (B*T,) — the token at position t-1 + bigram_probs_mx: (V, V) — pre-computed bigram transition probs + complement_alpha: float — strength of complementary weighting (0=standard CE) + """ + # Standard CE per token + ce_per_token = nn.losses.cross_entropy(logits, targets, reduction='none') + + # Look up bigram probability of each target given its predecessor + # p_bigram[i] = bigram_probs[prev_tokens[i], targets[i]] + p_bigram = bigram_probs_mx[prev_tokens, targets] # (B*T,) + + # Complementary weights: tokens easily predicted by bigrams get low weight + weights = 1.0 - complement_alpha * p_bigram + weights = mx.clip(weights, 0.1, 1.0) # Floor at 0.1 to avoid zero gradients + + # Normalize weights to preserve effective learning rate + weights = weights / weights.mean() + + loss = (ce_per_token * weights).mean() + return loss + + +# ============================================================ +# PART 2: Eval-Time BackoffNgramMixer (Causal, Full-Vocab Normalized) +# ============================================================ + +class BackoffNgramMixer: + """ + Causal n-gram language model with Kneser-Ney-style backoff. 
+ Built incrementally from already-scored tokens (backward-looking). + + Key properties for compliance: + 1. Full-vocabulary normalized: produces valid probability distribution + 2. Causal: only uses tokens at positions < current position + 3. Single-pass: score-first, then update cache + 4. No artifact cost: built from scratch during eval + + Algorithm: + - Maintain count tables for orders 1 through max_order + - For each position t: + 1. Query all orders using context tokens[t-order+1:t] + 2. Backoff from highest to lowest order with interpolation + 3. Produce P_ngram(token | context) for all vocab tokens + 4. Mix with neural model: P = (1-alpha)*P_neural + alpha*P_ngram + 5. Score position t using P + 6. Update count tables with token[t] + """ + def __init__(self, vocab_size=1024, max_order=7, + hash_buckets=2_000_000, alpha_mode='entropy_adaptive'): + self.vocab_size = vocab_size + self.max_order = max_order + self.hash_buckets = hash_buckets + self.alpha_mode = alpha_mode + + # Count tables: hash(context) -> array of vocab counts + # Using defaultdict for prototype; production uses fixed hash tables + self.counts = [defaultdict(lambda: np.zeros(vocab_size, dtype=np.float32)) + for _ in range(max_order + 1)] + self.total_counts = [defaultdict(float) for _ in range(max_order + 1)] + + def _hash_context(self, context_tokens): + """Hash a sequence of tokens to a bucket index.""" + h = 0 + for t in context_tokens: + h = (h * 31337 + int(t)) % self.hash_buckets + return h + + def _get_ngram_probs(self, context_tokens): + """ + Compute interpolated n-gram probability distribution. + Uses simple linear interpolation backoff. + """ + vocab_size = self.vocab_size + + # Start with uniform (order 0) + probs = np.ones(vocab_size, dtype=np.float64) / vocab_size + + # Interpolate from low to high order + for order in range(1, self.max_order + 1): + if len(context_tokens) < order: + break + + ctx = context_tokens[-order:] + ctx_hash = self._hash_context(ctx) + + counts = self.counts[order][ctx_hash] + total = self.total_counts[order][ctx_hash] + + if total > 0: + # Interpolation weight increases with total count (confidence) + lambda_order = total / (total + 5.0) # Simple discount + order_probs = (counts + 1e-10) / (total + 1e-10 * vocab_size) + + # Ensure normalization + order_probs = order_probs / order_probs.sum() + + probs = (1 - lambda_order) * probs + lambda_order * order_probs + + # Final normalization (paranoia) + probs = probs / probs.sum() + return probs + + def _compute_alpha(self, neural_logits): + """ + Entropy-adaptive mixing weight. + When neural model is uncertain (high entropy), trust n-grams more. + When neural model is confident (low entropy), trust it more. + """ + if self.alpha_mode == 'fixed': + return 0.3 + + # Compute neural model entropy + probs = np.exp(neural_logits - neural_logits.max()) + probs = probs / probs.sum() + entropy = -np.sum(probs * np.log2(probs + 1e-10)) + max_entropy = np.log2(self.vocab_size) # ~10 bits for 1024 vocab + + # Map entropy to alpha: high entropy → high alpha + normalized_entropy = entropy / max_entropy + alpha = 0.15 + 0.45 * normalized_entropy # Range: 0.15-0.60 + + return alpha + + def score_and_update(self, position, context_tokens, token_at_pos, + neural_log_probs): + """ + Score position and update cache. Must be called sequentially. + + Returns: log probability of token_at_pos under mixed distribution + """ + # 1. 
Get n-gram distribution (causal: uses only tokens before position) + ngram_probs = self._get_ngram_probs(context_tokens) + + # 2. Get mixing weight + alpha = self._compute_alpha(neural_log_probs) + + # 3. Mix distributions + neural_probs = np.exp(neural_log_probs) + neural_probs = neural_probs / neural_probs.sum() # Re-normalize + + mixed_probs = (1 - alpha) * neural_probs + alpha * ngram_probs + mixed_probs = mixed_probs / mixed_probs.sum() # Ensure normalization + + # 4. Score + log_prob = np.log(mixed_probs[token_at_pos] + 1e-30) + + # 5. Update cache (AFTER scoring — backward-looking) + for order in range(1, self.max_order + 1): + if len(context_tokens) >= order: + ctx = context_tokens[-order:] + ctx_hash = self._hash_context(ctx) + self.counts[order][ctx_hash][token_at_pos] += 1 + self.total_counts[order][ctx_hash] += 1 + + return log_prob +``` + +--- + +### ANGLE 8: INFORMATION-THEORETIC LOWER BOUND + +**Analysis:** + +The theoretical analysis is sound: +- Shannon entropy of clean English: ~1.0-1.3 bits/byte +- Best neural compressors on clean English: ~0.8-0.9 bits/byte (exploiting long-range structure) +- FineWeb is web text: noisier, multilingual, more diverse → practical floor ~1.0-1.1 bits/byte +- Current SOTA: 1.1194 (official), 1.1086 (pending) +- LoRA TTT: reached 1.0865 (#628, GEPA+legal TTT on 4×A100) + +**The 16MB constraint is the binding limit, not the theoretical floor.** A 16MB artifact encodes ~128M bits of information. The model has ~37M parameters in int6 ≈ 222M bits. After zstd compression, ~124M bits. The validation set is ~60M tokens × ~1.18 bytes/token ≈ ~70M bytes. To achieve 1.0 BPB, we need 70M bits of prediction accuracy from 124M bits of model. That's a ~1.8:1 ratio — tight but feasible. + +**Sub-1.10 BPB is achievable within 16MB** — the GEPA+TTT result (1.0865) proves this, though it used 4×A100 for 20K steps (more compute). The question is whether 600s on 8×H100 (~7K steps) provides enough training to reach the same quality. + +**Estimated gap breakdown (1.1194 → 1.10):** +- Better quantization/compression (entropy coding, NuMuon): ~0.003 +- Better architecture (shallow recurrence, EngramLite): ~0.005-0.008 +- Better training (Complementary Training, Mousse/Turbo-Muon): ~0.003-0.005 +- Eval-time BackoffNgramMixer: ~0.05-0.10 +- **Total estimated: ~0.06-0.12 BPB improvement → 1.00-1.06 BPB** + +Sub-1.10 is clearly achievable. Sub-1.05 is plausible. Sub-1.00 is at the edge of feasibility. + +--- + +## Ranked List: Most Promising Ideas + +### Tier 1 — Highest Impact, Proven Feasible + +| Rank | Idea | Source Angle | Est. BPB Gain | Risk | Difficulty | +|------|------|-------------|---------------|------|------------| +| 1 | **BackoffNgramMixer at eval time** | 5, 7 | 0.05-0.15 | Low (proven) / Moderate (compliance) | Moderate-High | +| 2 | **Complementary Training** | 5, 7 | 0.01-0.03 (over standard training) | Low (proven by #803) | Low | +| 3 | **EngramLite (gated multi-head hash)** | 5 | 0.003-0.008 (over BigramHash) | Low (proven by #1089) | Moderate | + +### Tier 2 — Moderate Impact, Good Feasibility + +| Rank | Idea | Source Angle | Est. 
BPB Gain | Risk | Difficulty | +|------|------|-------------|---------------|------|------------| +| 4 | **Shallow recurrence (+2 virtual layers)** | 6 | 0.003-0.008 | Low (proven by #686) | Low | +| 5 | **Skip-gram hash embedding** | 5 | 0.005-0.015 | Moderate (untried) | Low | +| 6 | **NuMuon optimizer** | 3 | 0.002-0.006 | Moderate | Low | +| 7 | **Mousse optimizer** | (Issue #140) | 0.003-0.008 | Low-Moderate | Low | +| 8 | **PPMII-style escape estimation** | 5, 7 | 0.01-0.03 (over basic backoff) | Moderate | Medium | + +### Tier 3 — Small/Speculative Impact + +| Rank | Idea | Source Angle | Est. BPB Gain | Risk | Difficulty | +|------|------|-------------|---------------|------|------------| +| 9 | **Byte-weighted CE loss** | 1 | 0.001-0.003 | High (below noise) | Very Low | +| 10 | **Custom entropy coding (ANS/Brotli)** | 3 | 0.001-0.003 | Moderate | Moderate | +| 11 | **Logistic-domain mixing** | 5 | 0.002-0.005 | Low | Very Low | + +### Dead Ideas (Don't Pursue) + +| Idea | Source Angle | Why Dead | +|------|-------------|----------| +| Non-uniform quantization (K-means, NF6) | 2 | MSE ≠ artifact size; higher index entropy defeats zstd | +| Hypernetwork weight generation | 4 | Information-theoretic impossibility at 185× compression | +| Deep recurrence (3+ loops) | 6 | Int6 error amplifies 900× over 3 cycles | +| Weight sparsification | 3 | Zeroing weights INCREASES artifact size (#1048) | +| Byte-level model | 5 | Far too slow for 600s training budget | +| Standalone n-gram (no neural) | 5 | 1.51 BPB with correct normalization — worse than neural | + +--- + +## Top 3: Proof-of-Concept Code + +### POC 1: Complementary Training + BackoffNgramMixer + +*(Full code stubs provided in Angle 5 and Angle 7 analysis above)* + +**Smoke test plan (M1, 100 steps):** +1. Pre-compute bigram statistics from first training shard +2. Modify `train_gpt_mlx_kl.py` loss to use `complementary_loss` +3. Train 100 steps, compare train_loss vs baseline +4. Expected: slightly higher train_loss (we're down-weighting easy tokens) but model learns harder patterns + +**Integration with existing stack:** +- Replace BigramHash with EngramLite in model init +- Add `--complement-alpha 0.5` flag +- Pre-compute bigram stats during data loading (one-time cost) +- At eval time, wrap sliding-window eval with BackoffNgramMixer + +### POC 2: EngramLite Gated Multi-Head Hash + +*(Full code stub provided in Angle 5 analysis above)* + +**Smoke test plan (M1, 100 steps):** +1. Replace `BigramHashEmbedding` with `EngramLiteEmbedding` in model +2. Config: hash_size=2048, embed_dim=128, output_dim=1024, n_heads=2, orders=(2,3) +3. Train 100 steps, compare train_loss vs BigramHash baseline +4. Expected: comparable or slightly better loss (gating suppresses noise) + +**Key implementation notes:** +- The gating mechanism is essential — without it, trigrams hurt (#609) +- Multi-head averaging reduces hash collision noise +- Parameter budget: 2 tables × 2048 × 128 × 2 heads + projection (128×1024) ≈ 1.2M params + - In int6+zstd: ~0.9MB — fits within 16MB budget easily + +### POC 3: Skip-Gram Hash + Shallow Recurrence Combo + +*(Skip-gram code stub provided in Angle 5. Shallow recurrence below.)* + +```python +# PROOF-OF-CONCEPT: Shallow Recurrence with Per-Pass Scalars +# Modification to GPT model in train_gpt_mlx_kl.py +# +# Key idea: Repeat layers 4 and 5 once each (11 → 13 virtual layers) +# with per-pass learnable scalar multipliers. + +class ShallowRecurrentGPT: + """ + Modification to existing GPT class. 
+ + Original: layers 0,1,2,3,4,5,6,7,8,9,10 (11 layers) + Modified: layers 0,1,2,3,4,5,4',5',6,7,8,9,10 (13 virtual, 11 unique) + + Layers 4' and 5' reuse weights from layers 4 and 5 but with + per-pass learnable scalars that differentiate the passes. + + Cost: ~2K extra parameters (scale + shift per layer per pass) + Benefit: ~70% of independent 12L quality gain + + CRITICAL: Only 2 loops (1 repeat). 3+ loops cause 900× quant error + amplification at int6. + """ + + def __init__(self, config): + # ... (standard init) ... + + # Per-pass scalars for recurrent layers (pseudocode — actual impl + # would use nn.Module parameters for gradient tracking) + self.recur_layers = [4, 5] # Which layers to repeat + self.n_passes = 2 # Original + 1 repeat + + # Learnable scale per pass (FiLM-lite) + # 2 recurrent layers × 1 repeat = 2 learnable scalars + # In real implementation: self.pass_scales = {key: nn.Parameter(mx.array(0.9))} + self.pass_scales = {} + for layer_idx in self.recur_layers: + for pass_idx in range(self.n_passes): + key = f'layer{layer_idx}_pass{pass_idx}' + if pass_idx == 0: + self.pass_scales[key] = 1.0 # Fixed (original pass) + else: + self.pass_scales[key] = 0.9 # Learnable (repeated pass) + + def forward(self, x): + """ + Forward pass with shallow recurrence. + + Instead of: 0 → 1 → 2 → 3 → 4 → 5 → 6 → 7 → 8 → 9 → 10 + We do: 0 → 1 → 2 → 3 → 4 → 5 → 4' → 5' → 6 → 7 → 8 → 9 → 10 + + Where 4' means layer 4 weights with pass_scales['layer4_pass1'] + """ + # Encoder layers (0 through num_encoder-1) + for i in range(self.num_encoder): + x = self.layers[i](x) + + x0 = x # Skip connection source + + # Decoder layers with recurrence + virtual_layer_order = [] + for i in range(self.num_encoder, self.num_layers): + virtual_layer_order.append((i, 0)) # (layer_idx, pass_idx) + if i in self.recur_layers: + virtual_layer_order.append((i, 1)) # Repeat + + for layer_idx, pass_idx in virtual_layer_order: + scale = self.pass_scales.get( + f'layer{layer_idx}_pass{pass_idx}', 1.0 + ) + + # Standard block forward with scaling + block_out = self.layers[layer_idx](x, x0) + + if pass_idx > 0: + # For repeated passes, apply dampened residual + x = x + scale * (block_out - x) + else: + x = block_out + + return x +``` + +**Smoke test plan (M1, 100 steps):** +1. Modify GPT forward to add shallow recurrence on layers 4,5 +2. Add SkipGramHashEmbedding alongside existing BigramHash +3. Train 100 steps, compare train_loss +4. Expected: slightly slower per step (~5-10% from extra 2 layer passes) but better loss per step + +--- + +## Implementation Details for Each Idea + +| Idea | Difficulty | BPB Est. | Risk | Compatible? 
| Dependencies | +|------|-----------|----------|------|-------------|-------------| +| BackoffNgramMixer | Moderate-High | 0.05-0.15 | Low/Moderate | Yes (eval-only) | numpy | +| Complementary Training | Low | 0.01-0.03 | Low | Yes | Pre-computed bigram stats | +| EngramLite | Moderate | 0.003-0.008 | Low | Yes (replaces BigramHash) | None | +| Shallow Recurrence | Low | 0.003-0.008 | Low | Yes (model arch change) | None | +| Skip-gram Hash | Low | 0.005-0.015 | Moderate | Yes (additive) | None | +| NuMuon | Low | 0.002-0.006 | Moderate | Yes (optimizer swap) | None | +| Byte-weighted CE | Very Low | 0.001-0.003 | High | Yes (loss change) | tokenizer stats | +| Custom entropy coding | Moderate | 0.001-0.003 | Moderate | Yes (post-training) | ANS library | + +--- + +## THE MOONSHOT + +### Complementary Training + EngramLite + BackoffNgramMixer: The Integrated Stack + +**Status: IMPLEMENTED** in `train_gpt_mlx_kl.py` (April 2026). + +**Env vars for full moonshot run (8×H100):** +``` +ENGRAM_LITE_ENABLED=1 COMPLEMENT_ALPHA=0.5 NGRAM_MIXER_ENABLED=1 NGRAM_ALPHA=0.25 NGRAM_MAX_ORDER=4 +``` + +**Smoke test (M1, 100 steps):** +``` +RUN_ID=moonshot_test ITERATIONS=100 TRAIN_BATCH_TOKENS=8192 VAL_LOSS_EVERY=0 VAL_BATCH_SIZE=8192 \ +WARMUP_STEPS=3 ENGRAM_LITE_ENABLED=1 COMPLEMENT_ALPHA=0.5 NGRAM_MIXER_ENABLED=1 EVAL_MODE=standard \ +python3 train_gpt_mlx_kl.py +``` + +**Why this is the single best bet nobody has fully combined:** + +The top competition results reveal three independent discoveries that, when properly integrated, form a system greater than the sum of its parts: + +1. **EngramLite (#1089):** Gated multi-head hashing makes n-gram features trainable end-to-end, fixing the TrigramHash failure (#609). This is the TRAINING-TIME component — the neural model learns to use n-gram context efficiently. + +2. **Complementary Training (#803):** Down-weighting tokens predictable by n-grams forces the neural model to specialize. This creates maximum COMPLEMENTARITY between neural and n-gram components. Without this, the neural model wastes capacity re-learning patterns the n-gram cache will handle at eval time. + +3. **BackoffNgramMixer (#1094):** A correctly-normalized eval-time n-gram cache with entropy-adaptive mixing. This is the EVAL-TIME component that adds 0.05-0.15 BPB improvement at zero artifact cost. + +**Why nobody has combined all three:** +- EngramLite is new (#1089, March 29) +- Complementary Training is new (#803, March 25) +- BackoffNgramMixer compliance was only clarified March 27 +- The three ideas emerged from different teams in different weeks + +**The integrated system:** + +``` +TRAINING: + 1. Pre-compute bigram/trigram stats from training data + 2. Train with EngramLite (gated bigram+trigram hash, replaces BigramHash) + 3. Use Complementary Training loss (down-weight easy n-gram tokens) + 4. Standard stack: 11L, 512d, 3×MLP, XSA, EMA, GPTQ-lite, etc. + 5. Result: neural model specialized for what n-grams can't predict + +EVAL: + 1. Load quantized model (standard sliding-window) + 2. For each token: + a. Score with neural model → neural_log_probs + b. Score with BackoffNgramMixer (orders 2-7) → ngram_probs + c. Entropy-adaptive alpha: high neural uncertainty → trust n-grams more + d. Mix: P = (1-alpha) * P_neural + alpha * P_ngram + e. Record log(P[true_token]) + f. Update n-gram cache (backward-looking) + 3. 
Result: complementary predictions from specialized components +``` + +**Expected BPB:** +- Baseline pure neural SOTA: 1.1086 (#1089) +- EngramLite + Complementary Training: ~1.10-1.11 (saving capacity for hard tokens) +- + BackoffNgramMixer at eval: ~0.95-1.05 +- + Skip-gram hash + shallow recurrence: ~0.92-1.00 + +**Why this could leapfrog the field:** +1. Nobody has done Complementary Training + EngramLite together (complementarity is maximized) +2. The BackoffNgramMixer is free (zero artifact cost) and additive +3. The neural model is BETTER at the tokens that matter because it doesn't waste capacity on n-gram-predictable tokens +4. This is principled compression theory: multiple experts with complementary specializations, mixed with learned weights + +**Risk factors:** +- Compliance: BackoffNgramMixer must produce correctly normalized full-vocabulary distributions +- The eval-time n-gram cache quality depends on validation set repetitiveness +- Complementary Training requires careful alpha tuning (too aggressive → model loses basic competence) +- EngramLite's parameter budget must be managed (hash tables compete with model capacity) + +**Estimated floor:** Even without the BackoffNgramMixer (which has compliance risk), EngramLite + Complementary Training + the full existing stack should reach ~1.10 BPB on pure neural — matching the current pending best with a principled path forward. + +--- + +## Honest Assessment: What Won't Work + +1. **Non-uniform quantization** — MSE ≠ artifact size. Killed by competition data (#1048, #316). +2. **Hypernetworks** — Information-theoretically impossible at required compression ratios. +3. **Deep recurrence (3+ loops)** — Int6 error amplification is a fundamental constraint. +4. **Byte-level models** — Too slow for 600s training. H-Net proved this at 1.90 BPB. +5. **Standalone n-gram replacement for neural** — 1.51 BPB with correct normalization (#978). +6. **Byte-weighted CE as a major lever** — The effect is real but ~0.001-0.003 BPB, below noise. +7. **Knowledge distillation** — 11ms/step I/O overhead fatal at 600s (#1029). +8. **Weight sparsification** — Increases artifact size, doesn't decrease it (#1048). + +The only genuine insights in this analysis are: +1. **The Complementary Training + EngramLite + BackoffNgramMixer integrated stack** (Moonshot) +2. **Skip-gram hashing** as a genuinely untried extension of the proven hash embedding approach +3. **The "MSE ≠ artifact size" principle** that eliminates entire categories of ideas + +Everything else is either already known to the competition community or below the significance threshold. 
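+---
+
+*Appendix: eval-loop wiring for the BackoffNgramMixer.* POC 1 says to "wrap sliding-window eval with BackoffNgramMixer" but never shows the driver loop. A minimal sequential sketch (slow reference form; `neural_log_probs_fn` is a placeholder for however the quantized model is queried per position):
+
+```python
+import numpy as np
+# Uses the BackoffNgramMixer class from the Angle 7 POC above.
+
+def eval_bpb_with_mixer(tokens, n_bytes, neural_log_probs_fn, mixer):
+    """tokens: 1-D int array of validation tokens; n_bytes: UTF-8 bytes they decode to.
+    neural_log_probs_fn(prefix) -> (V,) log-probs for the next token.
+    Position 0 is skipped for brevity; the real scorer must count every token."""
+    total_log_prob = 0.0
+    for t in range(1, len(tokens)):
+        context = tokens[max(0, t - mixer.max_order):t]  # causal: tokens before t only
+        neural_log_probs = neural_log_probs_fn(tokens[:t])
+        total_log_prob += mixer.score_and_update(
+            position=t,
+            context_tokens=context,
+            token_at_pos=int(tokens[t]),
+            neural_log_probs=neural_log_probs,
+        )
+    return -total_log_prob / np.log(2) / n_bytes  # nats -> bits, per byte
+```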
diff --git a/train_gpt_mlx_kl.py b/train_gpt_mlx_kl.py index 91735eaad5..b362f7fa75 100644 --- a/train_gpt_mlx_kl.py +++ b/train_gpt_mlx_kl.py @@ -1,21 +1,14 @@ #!/usr/bin/env python3 +"""KaiLean's Parameter Golf script — GPT training with int6 QAT, EMA, BigramHash, +EngramLite, SmearGate, XSA, complementary training, BackoffNgramMixer, and LoRA TTT.""" from __future__ import annotations - -import glob -import json -import math -import os -import pickle -import sys -import time -import uuid +import glob, json, math, os, pickle, sys, time, uuid, copy +import zstandard +from collections import defaultdict from collections.abc import Callable from pathlib import Path - import numpy as np import sentencepiece as spm -import zstandard as zstd - import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim @@ -23,7 +16,6 @@ COMPUTE_DTYPE = mx.bfloat16 - class Hyperparameters: data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") @@ -33,14 +25,14 @@ class Hyperparameters: iterations: int = int(os.environ.get("ITERATIONS", 20_000)) val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 50)) train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) - train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) mlx_eager_eval: bool = bool(int(os.environ.get("MLX_EAGER_EVAL", "1"))) warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) - warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 3_500)) + warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 3500)) max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) @@ -49,46 +41,65 @@ class Hyperparameters: num_heads: int = int(os.environ.get("NUM_HEADS", 8)) num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) mlp_mult: int = int(os.environ.get("MLP_MULT", 3)) - tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tie_embeddings: bool = True tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) - + eval_seq_len: int = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride: int = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs: int = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + bigram_hash_size: int = int(os.environ.get("BIGRAM_HASH_SIZE", 16384)) + qat_start_frac: float = float(os.environ.get("QAT_START_FRAC", 0.15)) + ema_decay: float = float(os.environ.get("EMA_DECAY", 0.995)) + ema_start_frac: float = float(os.environ.get("EMA_START_FRAC", 0.5)) + use_ortho_init: bool = bool(int(os.environ.get("USE_ORTHO_INIT", "1"))) + use_swa: bool = bool(int(os.environ.get("USE_SWA", "0"))) + swa_decay: float = 
float(os.environ.get("SWA_DECAY", "0.4")) + smear_enabled: bool = bool(int(os.environ.get("USE_SMEARGATE", os.environ.get("SMEAR_ENABLED", "1")))) + rope_dims: int = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale_enabled: bool = bool(int(os.environ.get( + "LN_SCALE_ENABLED", os.environ.get("USE_LN_SCALE", "1")))) + xsa_last_n: int = int(os.environ.get("XSA_LAST_N", 4)) + engram_lite_enabled: bool = bool(int(os.environ.get("ENGRAM_LITE_ENABLED", "0"))) + engram_hash_size: int = int(os.environ.get("ENGRAM_HASH_SIZE", "2048")) + engram_embed_dim: int = int(os.environ.get("ENGRAM_EMBED_DIM", "128")) + engram_n_heads: int = int(os.environ.get("ENGRAM_N_HEADS", "2")) + skipgram_hash_size: int = int(os.environ.get("SKIPGRAM_HASH_SIZE", "0")) + complement_alpha: float = float(os.environ.get("COMPLEMENT_ALPHA", "0.0")) + ngram_mixer_enabled: bool = bool(int(os.environ.get("NGRAM_MIXER_ENABLED", "0"))) + ngram_alpha: float = float(os.environ.get("NGRAM_ALPHA", "0.25")) + ngram_max_order: int = int(os.environ.get("NGRAM_MAX_ORDER", "4")) + eval_mode: str = os.environ.get("EVAL_MODE", "sliding") + ttt_enabled: bool = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_rank: int = int(os.environ.get("TTT_RANK", 4)) + ttt_lr: float = float(os.environ.get("TTT_LR", 0.001)) + ttt_steps: int = int(os.environ.get("TTT_STEPS", 2)) beta1: float = float(os.environ.get("BETA1", 0.9)) beta2: float = float(os.environ.get("BETA2", 0.95)) adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) - tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.03)) - matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.02)) - scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.02)) - muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.99)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) - muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1_500)) - grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - muon_wd: float = float(os.environ.get("MUON_WD", 0.04)) - adam_wd: float = float(os.environ.get("ADAM_WD", 0.04)) - - bigram_hash_size: int = int(os.environ.get("BIGRAM_HASH_SIZE", 16_384)) - use_ortho_init: bool = bool(int(os.environ.get("USE_ORTHO_INIT", "1"))) - ema_decay: float = float(os.environ.get("EMA_DECAY", 0.995)) - ema_start_frac: float = float(os.environ.get("EMA_START_FRAC", 0.5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_weight_decay: float = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + adam_weight_decay: float = float(os.environ.get("ADAM_WEIGHT_DECAY", 0.04)) late_qat_threshold: float = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) - learned_scales: bool = bool(int(os.environ.get("LEARNED_SCALES", "1"))) - eval_seq_len: int = int(os.environ.get("EVAL_SEQ_LEN", 2048)) - eval_stride: int = int(os.environ.get("EVAL_STRIDE", 64)) + use_gptq_lite: bool = bool(int(os.environ.get("USE_GPTQ_LITE", "1"))) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) out_dir: str = 
os.environ.get("OUT_DIR", "logs") @property def train_files(self) -> str: return f"{self.data_path}/fineweb_train_*.bin" - @property def val_files(self) -> str: return f"{self.data_path}/fineweb_val_*.bin" - @property def microbatch_tokens(self) -> int: return self.train_batch_tokens // self.grad_accum_steps @@ -106,68 +117,15 @@ def lr_mul(self, step: int, elapsed_ms: float) -> float: remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,log_qscale", - ).split(",") - if pattern +CONTROL_TENSOR_NAME_PATTERNS = ( + "attn_scale", "attn_scales", "mlp_scale", "mlp_scales", + "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", + "smear", ) -INT6_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "INT6_KEEP_FLOAT_FP32_NAME_PATTERNS", - ",".join(CONTROL_TENSOR_NAME_PATTERNS), - ).split(",") - if pattern -) -INT6_KEEP_FLOAT_MAX_NUMEL = 65_536 -INT6_KEEP_FLOAT_STORE_DTYPE = np.float16 -INT6_PER_ROW_SCALE_DTYPE = np.float16 -MX_DTYPE_FROM_NAME = { - "float32": mx.float32, - "float16": mx.float16, - "bfloat16": mx.bfloat16, -} - - -def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: - usable_total = (total_tokens // seq_len) * seq_len - if usable_total <= 0: - raise ValueError(f"token budget too small for seq_len={seq_len}") - usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) - chunks: list[int] = [] - remaining = usable_total - while remaining > 0: - chunk = min(remaining, usable_chunk) - chunks.append(chunk) - remaining -= chunk - return chunks - - -def accumulate_flat_grads( - accum: dict[str, mx.array] | None, - grads_tree: dict, - scale: float, -) -> dict[str, mx.array]: - flat = dict(tree_flatten(grads_tree)) - if accum is None: - return {k: g * scale for k, g in flat.items()} - for k, g in flat.items(): - accum[k] = accum[k] + g * scale - return accum - - -def _np_float32(arr: mx.array) -> np.ndarray: - return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) - def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) - def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array: a, b, c = 3.4445, -4.7750, 2.0315 x = g.astype(mx.float32) @@ -183,36 +141,6 @@ def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.ar x = x.T return x.astype(g.dtype) - -def fake_quant_int6(w: mx.array) -> mx.array: - scale = mx.maximum(mx.max(mx.abs(w), axis=-1, keepdims=True) / 31.0, mx.array(1e-8, dtype=mx.float32)) - w_q = mx.clip(mx.round(w / scale), -32, 31) * scale - return w + mx.stop_gradient(w_q - w) - - -def pack_int6_np(arr_int8: np.ndarray) -> np.ndarray: - flat = (arr_int8.reshape(-1).astype(np.int16) + 32).astype(np.uint8) - pad_len = (4 - len(flat) % 4) % 4 - if pad_len: - flat = np.concatenate([flat, np.zeros(pad_len, dtype=np.uint8)]) - a, b, c, d = flat[0::4], flat[1::4], flat[2::4], flat[3::4] - b0 = (a << 2) | (b >> 4) - b1 = ((b & 0x0F) << 4) | (c >> 2) - b2 = ((c & 0x03) << 6) | d - return np.ascontiguousarray(np.stack([b0, b1, b2], axis=-1).reshape(-1)) - - -def unpack_int6_np(packed: np.ndarray, numel: int) -> np.ndarray: - 
packed = packed.reshape(-1, 3) - b0, b1, b2 = packed[:, 0], packed[:, 1], packed[:, 2] - a = b0 >> 2 - b = ((b0 & 0x03) << 4) | (b1 >> 4) - c = ((b1 & 0x0F) << 2) | (b2 >> 6) - d = b2 & 0x3F - flat = np.stack([a, b, c, d], axis=-1).reshape(-1)[:numel] - return (flat.astype(np.int16) - 32).astype(np.int8) - - def load_data_shard(path: Path) -> np.ndarray: header_bytes = 256 * np.dtype(" np.ndarray: if path.stat().st_size != header_bytes + num_tokens * token_bytes: raise ValueError(f"Shard size mismatch for {path}") tokens = np.fromfile(path, dtype=" None: self.file_idx = (self.file_idx + 1) % len(self.files) if self.file_idx == 0: self.epoch += 1 - if self.log_fn is not None: - self.log_fn( - f"WARNING: starting epoch:{self.epoch} " - f"dataset:{self.dataset_name} train_shards:{len(self.files)}" - ) + if self.log_fn: + self.log_fn(f"WARNING: starting epoch:{self.epoch} dataset:{self.dataset_name}") self.tokens = load_data_shard(self.files[self.file_idx]) self.pos = 0 def take(self, n: int) -> np.ndarray: - chunks: list[np.ndarray] = [] + chunks = [] left = n while left > 0: if self.pos >= self.tokens.size: @@ -269,14 +186,8 @@ def take(self, n: int) -> np.ndarray: left -= k return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) - class TokenLoader: - def __init__( - self, - pattern: str, - log_fn: Callable[[str], None] | None = None, - dataset_name: str = "", - ): + def __init__(self, pattern: str, log_fn=None, dataset_name: str = ""): self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: @@ -288,94 +199,225 @@ def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.arra y = chunk[1:].reshape(-1, seq_len) return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) - -class CastedLinear(nn.Module): - def __init__(self, in_dim: int, out_dim: int): +class BigramHashEmbedding(nn.Module): + def __init__(self, hash_size: int, dim: int): super().__init__() - base = nn.Linear(in_dim, out_dim, bias=False) - self.weight = base.weight.astype(mx.float32) - init_scale = mx.maximum( - mx.max(mx.abs(self.weight), axis=-1, keepdims=True) / 31.0, - mx.array(1e-5, dtype=mx.float32), - ) - self.log_qscale = mx.log(init_scale) - self.use_qat = False - self.use_learned_scales = False + self.hash_size = hash_size + self.table = nn.Embedding(hash_size, dim) + self.table.weight = self.table.weight * 0.02 - def __call__(self, x: mx.array) -> mx.array: - if self.use_qat: - if self.use_learned_scales: - s = mx.clip(mx.exp(self.log_qscale), 1e-5, 1.0) - w_q = mx.clip(mx.round(self.weight / s), -32, 31) * s - w = self.weight + mx.stop_gradient(w_q - self.weight) - else: - w = fake_quant_int6(self.weight) + def __call__(self, tokens: mx.array) -> mx.array: + """tokens: (B, T) int32 → bigram embeddings: (B, T, dim)""" + t_prev = tokens[:, :-1] + t_curr = tokens[:, 1:] + idx = mx.remainder(t_prev * 31337 + t_curr, self.hash_size) + bigram_emb = self.table(idx) + pad = mx.zeros((tokens.shape[0], 1, bigram_emb.shape[-1]), dtype=bigram_emb.dtype) + return mx.concatenate([pad, bigram_emb], axis=1) + +class EngramLiteEmbedding(nn.Module): + def __init__(self, hash_size: int = 2048, embed_dim: int = 128, + output_dim: int = 1024, n_heads: int = 2, + orders: tuple = (2, 3)): + super().__init__() + self.hash_size = hash_size + self.embed_dim = embed_dim + self.output_dim = output_dim + self.n_heads = n_heads + self.orders = list(orders) + _all_primes = [31337, 59999, 73721, 97531] + 
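# Illustrative example of the rolling prime hash used by BigramHashEmbedding above (values are
# made up): with hash_size=2048,
#   >>> (17 * 31337 + 42) % 2048
#   291
# so the bigram (17, 42) shares a 2048-entry table row with any colliding pair; collisions are the
# price of keeping the table small enough for the artifact budget.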
if n_heads > len(_all_primes): + raise ValueError(f"EngramLiteEmbedding: n_heads={n_heads} exceeds max supported ({len(_all_primes)})") + self._primes = _all_primes[:n_heads] + self.tables = { + f"order_{o}": nn.Embedding(hash_size, embed_dim) + for o in orders + } + for tbl in self.tables.values(): + tbl.weight = tbl.weight * 0.01 + self.proj = nn.Linear(embed_dim, output_dim, bias=False) + self.proj.weight = self.proj.weight * 0.01 + self.gate_proj = nn.Linear(embed_dim, len(orders), bias=True) + self.gate_proj.bias = mx.full((len(orders),), -2.0) + self.gate_proj.weight = self.gate_proj.weight * 0.01 + + def _hash_ngram(self, tokens: mx.array, order: int, head_idx: int): + """Hash n-gram context for given order and hash head.""" + prime = self._primes[head_idx] + if order == 2: + t_prev = tokens[:, :-1] + t_curr = tokens[:, 1:] + idx = mx.remainder(t_prev * prime + t_curr, self.hash_size) + valid_start = 1 + elif order == 3: + t_prev2 = tokens[:, :-2] + t_prev1 = tokens[:, 1:-1] + t_curr = tokens[:, 2:] + idx = mx.remainder( + t_prev2 * (prime * prime) + t_prev1 * prime + t_curr, + self.hash_size + ) + valid_start = 2 else: - w = self.weight - return x @ w.astype(x.dtype).T + raise ValueError(f"n-gram order {order} not supported") + return idx, valid_start - -class BigramHashEmbedding(nn.Module): - def __init__(self, hash_size: int, vocab_size: int): + def __call__(self, tokens: mx.array) -> mx.array: + """tokens: (B, T) → (B, T, output_dim) additive logit bias""" + B, T = tokens.shape + combined = mx.zeros((B, T, self.embed_dim), dtype=mx.float32) + for order in self.orders: + tbl = self.tables[f"order_{order}"] + head_sum = None + for hi in range(self.n_heads): + idx, valid_start = self._hash_ngram(tokens, order, hi) + emb = tbl(idx).astype(mx.float32) + pad = mx.zeros((B, valid_start, self.embed_dim), dtype=mx.float32) + emb = mx.concatenate([pad, emb], axis=1) + head_sum = emb if head_sum is None else head_sum + emb + combined = combined + head_sum / self.n_heads + gate = mx.sigmoid(self.gate_proj(combined)) + gate_scalar = gate.mean(axis=-1, keepdims=True) + return self.proj(combined) * gate_scalar + +class SkipGramHashEmbedding(nn.Module): + def __init__(self, hash_size: int = 4096, dim: int = 1024, + skip_patterns: list = None): super().__init__() self.hash_size = hash_size - self.table = nn.Embedding(hash_size, vocab_size) - self.table.weight = (mx.random.normal(self.table.weight.shape, dtype=mx.float32) * 0.02).astype(mx.float32) + self.dim = dim + self.skip_patterns = skip_patterns if skip_patterns is not None else [[-1, -3], [-1, -5], [-2, -4]] + self.tables = { + f"skip_{i}": nn.Embedding(hash_size, dim) + for i in range(len(self.skip_patterns)) + } + for tbl in self.tables.values(): + tbl.weight = tbl.weight * 0.01 def __call__(self, tokens: mx.array) -> mx.array: - prev = tokens[:, :-1] - curr = tokens[:, 1:] - idx = mx.remainder(prev * 31337 + curr, self.hash_size) - bias = self.table(idx) - pad = mx.zeros((tokens.shape[0], 1, bias.shape[-1]), dtype=bias.dtype) - return mx.concatenate([pad, bias], axis=1) + """tokens: (B, T) → (B, T, dim) additive logit bias""" + B, T = tokens.shape + output = mx.zeros((B, T, self.dim), dtype=mx.float32) + for i, pattern in enumerate(self.skip_patterns): + tbl = self.tables[f"skip_{i}"] + min_offset = min(pattern) + valid_start = abs(min_offset) + if valid_start >= T: + continue + hash_val = mx.zeros((B, T - valid_start), dtype=mx.int32) + prime = 31337 + for offset in pattern: + start = valid_start + offset + end = T + offset + hash_val 
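# EngramLite hashing above: order-2 uses (t_prev*p + t_curr) mod H and order-3 uses
# (t_prev2*p^2 + t_prev1*p + t_curr) mod H, with a distinct prime p per hash head so the heads
# collide on different n-gram pairs. The gate bias is initialised to -2.0, i.e.
# sigmoid(-2.0) ~= 0.12, so the hashed n-gram logit bias starts almost closed and has to earn its
# weight during training.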
= hash_val * prime + tokens[:, start:end] + idx = mx.remainder(mx.abs(hash_val), self.hash_size) + emb = tbl(idx).astype(mx.float32) + pad = mx.zeros((B, valid_start, self.dim), dtype=mx.float32) + emb = mx.concatenate([pad, emb], axis=1) + output = output + emb + return output + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = mx.full((dim,), 3.0, dtype=mx.float32) + + def __call__(self, x: mx.array) -> mx.array: + g = mx.sigmoid(self.gate).astype(x.dtype) + x_prev = mx.concatenate([mx.zeros_like(x[:, :1]), x[:, :-1]], axis=1) + return g * x + (1.0 - g) * x_prev + +def fake_quant_int6(w: mx.array) -> mx.array: + """Simulate int6 quantization during training (STE for gradients).""" + scale = mx.max(mx.abs(w), keepdims=True) / 31.0 + 1e-8 + w_q = mx.clip(mx.round(w / scale), -32, 31) * scale + return w + mx.stop_gradient(w_q - w) + +class EMABuffer: + def __init__(self, model, decay: float = 0.995): + self.decay = decay + self.shadow = {} + for k, v in tree_flatten(model.parameters()): + key = ".".join(str(p) for p in k) if isinstance(k, (list, tuple)) else k + self.shadow[key] = mx.array(v) + def update(self, model): + d = self.decay + for k, v in tree_flatten(model.parameters()): + key = ".".join(str(p) for p in k) if isinstance(k, (list, tuple)) else k + if key in self.shadow: + self.shadow[key] = d * self.shadow[key] + (1.0 - d) * v + mx.eval(list(self.shadow.values())) + + def apply(self, model): + """Replace model weights with EMA weights.""" + model.update(tree_unflatten(list(self.shadow.items()))) + + def state_dict(self): + return dict(self.shadow) + +class CastedLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) + + def __call__(self, x: mx.array, use_qat: bool = False) -> mx.array: + w = self.weight + if use_qat: + w = fake_quant_int6(w) + return x @ w.astype(x.dtype).T class RMSNormNoWeight(nn.Module): def __call__(self, x: mx.array) -> mx.array: return rms_norm(x) - class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, + rope_dims: int = 0, use_xsa: bool = False): super().__init__() - if dim % num_heads != 0: - raise ValueError("model_dim must be divisible by num_heads") - if num_heads % num_kv_heads != 0: - raise ValueError("num_heads must be divisible by num_kv_heads") self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = dim // num_heads - if self.head_dim % 2 != 0: - raise ValueError("head_dim must be even for RoPE") - kv_dim = self.num_kv_heads * self.head_dim + kv_dim = num_kv_heads * self.head_dim self.c_q = CastedLinear(dim, dim) self.c_k = CastedLinear(dim, kv_dim) self.c_v = CastedLinear(dim, kv_dim) self.proj = CastedLinear(dim, dim) self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init - self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) + if rope_dims > 0: + self.rope = nn.RoPE(rope_dims, traditional=False, base=rope_base) + else: + self.rope = None self.scale = self.head_dim ** -0.5 - - def __call__(self, x: mx.array) -> mx.array: + self.use_xsa = use_xsa + + def _xsa(self, y: mx.array, v: mx.array) -> mx.array: + """XSA: subtract self-value component (PR #198).""" + B, H, T, D = y.shape + Hkv = v.shape[1] + group = H // Hkv + y_g = 
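# Quick reference for the helper modules defined above (illustrative numbers):
#   - SmearGate: gate init 3.0 -> sigmoid(3.0) ~= 0.95, so each position starts as roughly 95% its
#     own embedding plus 5% of the previous token's embedding.
#   - fake_quant_int6: straight-through estimator; the forward value equals the int6-rounded
#     weight, while the gradient is the identity because stop_gradient(w_q - w) contributes
#     nothing to d(out)/d(w).
#   - EMABuffer: shadow <- decay*shadow + (1-decay)*param; with decay=0.995 the shadow averages
#     over roughly the last 1/(1-0.995) = 200 steps.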
y.reshape(B, Hkv, group, T, D) + v_norm = v / (mx.sqrt((v * v).sum(-1, keepdims=True)) + 1e-6) + vn = v_norm[:, :, None, :, :] + proj = (y_g * vn).sum(-1, keepdims=True) * vn + return (y_g - proj).reshape(B, H, T, D) + + def __call__(self, x: mx.array, use_qat: bool = False) -> mx.array: bsz, seqlen, dim = x.shape - q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) - k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) - q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) - k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) + q = self.c_q(x, use_qat).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.c_k(x, use_qat).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.c_v(x, use_qat).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + q = rms_norm(q).astype(COMPUTE_DTYPE) + k = rms_norm(k).astype(COMPUTE_DTYPE) + if self.rope is not None: + q = self.rope(q) + k = self.rope(k) q = q * self.q_gain.astype(q.dtype)[None, :, None, None] y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") - y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) - return self.proj(y) - + if self.use_xsa: + y = self._xsa(y, v) + return self.proj(y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim), use_qat) class MLP(nn.Module): def __init__(self, dim: int, mlp_mult: int): @@ -384,92 +426,92 @@ def __init__(self, dim: int, mlp_mult: int): self.fc = CastedLinear(dim, hidden) self.proj = CastedLinear(hidden, dim) - def __call__(self, x: mx.array) -> mx.array: - x = nn.relu(self.fc(x)) - return self.proj(x * x) - + def __call__(self, x: mx.array, use_qat: bool = False) -> mx.array: + x = nn.relu(self.fc(x, use_qat)) + return self.proj(x * x, use_qat) class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + mlp_mult: int, rope_base: float, qk_gain_init: float, + rope_dims: int = 0, use_xsa: bool = False, + layer_idx: int = 0, use_ln_scale: bool = True): super().__init__() self.attn_norm = RMSNormNoWeight() self.mlp_norm = RMSNormNoWeight() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, use_xsa=use_xsa) self.mlp = MLP(dim, mlp_mult) self.attn_scale = mx.ones((dim,), dtype=mx.float32) self.mlp_scale = mx.ones((dim,), dtype=mx.float32) - self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) + self.resid_mix = mx.array(np.stack(( + np.ones((dim,), dtype=np.float32), + np.zeros((dim,), dtype=np.float32), + ))) + self.ln_scale_factor = float(1.0 / math.sqrt(layer_idx + 1)) if use_ln_scale else 1.0 - def __call__(self, x: mx.array, x0: mx.array) -> mx.array: + def __call__(self, x: mx.array, x0: mx.array, use_qat: bool = False) -> mx.array: mix = self.resid_mix.astype(x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x).astype(COMPUTE_DTYPE) * s, use_qat) x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out - x = x + 
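# XSA (the _xsa helper above, per the PR #198 note in its docstring) removes, at every position,
# the component of the attention output that lies along that position's own L2-normalised value
# vector: y' = y - (y . v_hat) v_hat. In the degenerate case y == v the subtraction essentially
# zeroes the output, i.e. only the cross-token part of the mix survives. It is applied only to the
# blocks whose use_xsa flag is set (the last XSA_LAST_N decoder layers, see GPT below).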
self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x).astype(COMPUTE_DTYPE) * s, use_qat) return x - -def ortho_init_matrix(weight: mx.array) -> mx.array: - w = np.array(weight.astype(mx.float32), dtype=np.float64, copy=True) - u, _, vt = np.linalg.svd(w, full_matrices=False) - return mx.array((u @ vt).astype(np.float32, copy=False) * 0.5, dtype=mx.float32) - - class GPT(nn.Module): - def __init__( - self, - vocab_size: int, - num_layers: int, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - logit_chunk_tokens: int, - logit_softcap: float, - rope_base: float, - tied_embed_init_std: float, - qk_gain_init: float, - bigram_hash_size: int, - use_ortho_init: bool, - ): + def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads, + mlp_mult, logit_chunk_tokens, logit_softcap, rope_base, + tied_embed_init_std, qk_gain_init, bigram_hash_size, + use_ortho_init, rope_dims: int = 0, xsa_last_n: int = 0, + use_ln_scale: bool = True, smear_enabled: bool = True, + engram_lite_enabled: bool = False, engram_hash_size: int = 2048, + engram_embed_dim: int = 128, engram_n_heads: int = 2, + skipgram_hash_size: int = 0): super().__init__() - if logit_softcap <= 0.0: - raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") self.logit_chunk_tokens = logit_chunk_tokens self.logit_softcap = logit_softcap - self.vocab_size = vocab_size - self.bigram_hash = BigramHashEmbedding(bigram_hash_size, vocab_size) if bigram_hash_size > 0 else None - + self.use_qat = False self.tok_emb = nn.Embedding(vocab_size, dim) self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) + self.smear = SmearGate(dim) if smear_enabled else None + xsa_decoder_start = max(0, self.num_decoder_layers - xsa_last_n) if xsa_last_n > 0 else self.num_decoder_layers self.blocks = [ - Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) - for _ in range(num_layers) + Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + rope_dims=rope_dims, + use_xsa=(i >= self.num_encoder_layers + xsa_decoder_start), + layer_idx=i, use_ln_scale=use_ln_scale) + for i in range(num_layers) ] self.final_norm = RMSNormNoWeight() - + if engram_lite_enabled: + self.engram_lite = EngramLiteEmbedding( + hash_size=engram_hash_size, embed_dim=engram_embed_dim, + output_dim=vocab_size, n_heads=engram_n_heads, orders=(2, 3)) + self.bigram_hash = None + else: + self.engram_lite = None + self.bigram_hash = BigramHashEmbedding(bigram_hash_size, vocab_size) if bigram_hash_size > 0 else None + self.skipgram_hash = SkipGramHashEmbedding(hash_size=skipgram_hash_size, dim=vocab_size) if skipgram_hash_size > 0 else None + for b in self.blocks: + b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) + b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) if use_ortho_init: - for block in self.blocks: - block.attn.c_q.weight = ortho_init_matrix(block.attn.c_q.weight) - block.attn.c_k.weight = ortho_init_matrix(block.attn.c_k.weight) - block.attn.c_v.weight = ortho_init_matrix(block.attn.c_v.weight) - block.mlp.fc.weight = ortho_init_matrix(block.mlp.fc.weight) - for block in self.blocks: - block.attn.proj.weight = mx.zeros_like(block.attn.proj.weight) - block.mlp.proj.weight = 
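# Two initialisation/scaling details from the Block and GPT definitions above:
#   - ln_scale_factor = 1/sqrt(layer_idx+1) shrinks the normalised input fed to attention/MLP as
#     depth grows: 1.00 for block 0, ~0.71 for block 1, ~0.58 for block 2, 0.50 for block 3, ...
#   - attn.proj and mlp.proj weights are zero-initialised, so at step 0 every residual branch is a
#     no-op and the model starts as (smeared) token embeddings plus the hashed logit-bias tables.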
mx.zeros_like(block.mlp.proj.weight) + for b in self.blocks: + for linear in [b.attn.c_q, b.attn.c_k, b.attn.c_v, b.mlp.fc]: + w = linear.weight + m, n = w.shape + flat = mx.random.normal((m, n)).astype(mx.float32) + u, s, vt = mx.linalg.svd(flat, stream=mx.cpu) + if m >= n: + linear.weight = (u[:, :n] * 0.5).astype(w.dtype) + else: + linear.weight = (vt[:m, :] * 0.5).astype(w.dtype) self.tok_emb.weight = ( - mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std + mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) + * tied_embed_init_std ).astype(COMPUTE_DTYPE) def softcap(self, logits: mx.array) -> mx.array: @@ -478,234 +520,261 @@ def softcap(self, logits: mx.array) -> mx.array: def __call__(self, input_ids: mx.array) -> mx.array: x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) + if self.smear is not None: + x = self.smear(x) x0 = x - skips: list[mx.array] = [] + skips = [] + qat = self.use_qat for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) + x = self.blocks[i](x, x0, qat) skips.append(x) for i in range(self.num_decoder_layers): if skips: x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.blocks[self.num_encoder_layers + i](x, x0, qat) return self.final_norm(x) - def forward_logits(self, input_ids: mx.array) -> mx.array: - hidden = self(input_ids) - logits = hidden @ self.tok_emb.weight.astype(hidden.dtype).T - logits = self.softcap(logits) - if self.bigram_hash is not None: - logits = logits + self.bigram_hash(input_ids).astype(logits.dtype) + def _add_logit_biases(self, logits: mx.array, input_ids: mx.array) -> mx.array: + """Add all enabled logit bias modules (BigramHash/EngramLite/SkipGram).""" + vocab = self.tok_emb.weight.shape[0] + if self.engram_lite is not None: + bias = self.engram_lite(input_ids).reshape(-1, vocab) + logits = logits + bias.astype(logits.dtype) + elif self.bigram_hash is not None: + bias = self.bigram_hash(input_ids).reshape(-1, vocab) + logits = logits + bias.astype(logits.dtype) + if self.skipgram_hash is not None: + bias = self.skipgram_hash(input_ids).reshape(-1, vocab) + logits = logits + bias.astype(logits.dtype) return logits def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) y = target_ids.reshape(-1) - bigram_bias = None - if self.bigram_hash is not None: - bigram_bias = self.bigram_hash(input_ids).reshape(-1, self.vocab_size) - if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: - logits = self.softcap(x @ self.tok_emb.weight.astype(x.dtype).T) - if bigram_bias is not None: - logits = logits + bigram_bias.astype(logits.dtype) - return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") - loss_sum = mx.array(0.0, dtype=mx.float32) - n = int(x.shape[0]) - for s in range(0, n, self.logit_chunk_tokens): - e = min(s + self.logit_chunk_tokens, n) - logits = self.softcap(x[s:e] @ self.tok_emb.weight.astype(x.dtype).T) - if bigram_bias is not None: - logits = logits + bigram_bias[s:e].astype(logits.dtype) - loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") - return loss_sum / float(n) - - -class EMABuffer: - def __init__(self, flat_params: dict[str, mx.array], decay: float): - self.decay = decay - self.dtypes = {k: v.dtype for k, v in flat_params.items()} - self.shadow = {k: _np_float32(v) for k, v in flat_params.items()} - - def 
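# Logit-bias composition in _add_logit_biases above: EngramLite and BigramHash are mutually
# exclusive (when engram_lite is enabled, bigram_hash is set to None), while SkipGramHash, if
# configured, is added on top of whichever of the two is active. Each produces a (tokens, vocab)
# additive bias that is summed into the tied-embedding logits after the softcap, so they act as
# cheap table-lookup corrections rather than extra transformer capacity.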
update(self, flat_params: dict[str, mx.array]) -> None: - d = self.decay - for k, v in flat_params.items(): - self.shadow[k] = d * self.shadow[k] + (1.0 - d) * _np_float32(v) - - def as_mlx(self) -> dict[str, mx.array]: - return {k: mx.array(v, dtype=self.dtypes[k]) for k, v in self.shadow.items()} - - -def all_casted_linears(model: GPT) -> list[CastedLinear]: - layers: list[CastedLinear] = [] - for block in model.blocks: - layers.extend( - [ - block.attn.c_q, - block.attn.c_k, - block.attn.c_v, - block.attn.proj, - block.mlp.fc, - block.mlp.proj, - ] - ) - return layers - - -def capture_qat_flags(model: GPT) -> list[tuple[bool, bool]]: - return [(layer.use_qat, layer.use_learned_scales) for layer in all_casted_linears(model)] - - -def restore_qat_flags(model: GPT, flags: list[tuple[bool, bool]]) -> None: - for layer, (use_qat, use_learned_scales) in zip(all_casted_linears(model), flags, strict=True): - layer.use_qat = use_qat - layer.use_learned_scales = use_learned_scales - - -def set_qat_mode(model: GPT, enabled: bool, use_learned_scales: bool) -> None: - for layer in all_casted_linears(model): - layer.use_qat = enabled - layer.use_learned_scales = enabled and use_learned_scales - - -def flatten_params(model: GPT) -> dict[str, mx.array]: - return {k: v for k, v in tree_flatten(model.parameters())} - - -def update_model_params(model: GPT, flat_params: dict[str, mx.array]) -> None: - model.update(tree_unflatten(list(flat_params.items()))) - + logits = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits) + logits = self._add_logit_biases(logits, input_ids) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") -def compile_model_fns(model: GPT): - compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, - outputs=model.state, - ) - return compiled_loss, compiled_loss_and_grad + def complementary_loss(self, input_ids: mx.array, target_ids: mx.array, + bigram_probs: mx.array, alpha: float) -> mx.array: + """CE loss that down-weights tokens easily predicted by bigrams.""" + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + logits = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits) + logits = self._add_logit_biases(logits, input_ids) + ce_per_token = nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="none") + prev_tokens = input_ids.reshape(-1) + p_bigram = bigram_probs[prev_tokens, y] + weights = 1.0 - alpha * p_bigram.astype(mx.float32) + weights = mx.clip(weights, 0.1, 1.0) + weights = weights / (weights.mean() + 1e-8) + return (ce_per_token * weights).mean() + + def token_losses(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + """Return (B, T) per-token NLL — used for sliding-window eval.""" + B, T = input_ids.shape + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + logits = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits) + logits = self._add_logit_biases(logits, input_ids) + nll = nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="none") + return nll.reshape(B, T) + def token_logits(self, input_ids: mx.array) -> mx.array: + """Return (B, T, V) raw logits — used by BackoffNgramMixer for mixing.""" + B, T = input_ids.shape + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + logits = x @ 
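# complementary_loss above re-weights per-token CE by how predictable the target already is from
# a precomputed bigram table: weight = clip(1 - alpha * P_bigram(target | prev), 0.1, 1.0), then
# the weights are renormalised to mean 1 so the overall loss scale is unchanged. Illustrative
# numbers with COMPLEMENT_ALPHA=0.5: a token the bigram table assigns p=0.9 gets pre-normalisation
# weight 0.55, while one with p=0.02 keeps weight 0.99, pushing gradient toward what bigrams miss.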
self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits) + logits = self._add_logit_biases(logits, input_ids) + return logits.reshape(B, T, -1) class Muon: - def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): + def __init__(self, keys, params, args): self.keys = keys self.args = args self.buffers = {k: mx.zeros_like(params[k]) for k in keys} - def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: + def step(self, params, grads, step, lr_mul): if self.args.muon_momentum_warmup_steps: t = min(step / self.args.muon_momentum_warmup_steps, 1.0) momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum else: momentum = self.args.muon_momentum lr = self.args.matrix_lr * lr_mul - out: dict[str, mx.array] = {} + wd = self.args.muon_weight_decay + out = {} for k in self.keys: - p = params[k] - g = grads[k] + p, g = params[k], grads[k] buf = momentum * self.buffers[k] + g self.buffers[k] = buf g_eff = g + momentum * buf g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) - out[k] = p - lr * (g_ortho * scale).astype(p.dtype) - out[k] = out[k] - lr * self.args.muon_wd * p + g_with_wd = g_ortho * scale + wd * p + out[k] = p - lr * g_with_wd.astype(p.dtype) return out - -MUON_WEIGHT_SUFFIXES = ( - ".attn.c_q.weight", - ".attn.c_k.weight", - ".attn.c_v.weight", - ".attn.proj.weight", - ".mlp.fc.weight", - ".mlp.proj.weight", -) - - class SplitOptimizers: - def __init__(self, model: GPT, args: Hyperparameters): + def __init__(self, model, args): self.args = args - params = flatten_params(model) + params = dict(tree_flatten(model.parameters())) self.embed_key = "tok_emb.weight" + _module_prefixes = ( + "blocks.", "bigram_hash.", "engram_lite.", "skipgram_hash.", + ) self.matrix_keys = [ - k - for k in params - if k == "bigram_hash.table.weight" or any(k.endswith(suffix) for suffix in MUON_WEIGHT_SUFFIXES) + k for k, p in params.items() + if any(k.startswith(pfx) for pfx in _module_prefixes) + and p.ndim == 2 + and not any(pat in k for pat in CONTROL_TENSOR_NAME_PATTERNS) + ] + self.scalar_keys = [ + k for k, p in params.items() + if k == "skip_weights" or ( + any(k.startswith(pfx) for pfx in _module_prefixes) + and (p.ndim < 2 or any(pat in k for pat in CONTROL_TENSOR_NAME_PATTERNS)) + ) ] - self.scalar_keys = [k for k in params if k not in self.matrix_keys and k != self.embed_key] self.muon = Muon(self.matrix_keys, params, args) - self.adam_embed = optim.Adam( - learning_rate=args.tied_embed_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) - self.adam_scalar = optim.Adam( - learning_rate=args.scalar_lr, - betas=[args.beta1, args.beta2], - eps=args.adam_eps, - bias_correction=True, - ) + self.adam_embed = optim.Adam(learning_rate=args.tied_embed_lr, + betas=[args.beta1, args.beta2], eps=args.adam_eps, bias_correction=True) + self.adam_scalar = optim.Adam(learning_rate=args.scalar_lr, + betas=[args.beta1, args.beta2], eps=args.adam_eps, bias_correction=True) - def _apply_adam_weight_decay( - self, - updated: dict[str, mx.array], - params: dict[str, mx.array], - keys: list[str], - lr: float, - ) -> None: - if self.args.adam_wd <= 0.0: - return - for k in keys: - updated[k] = updated[k] - lr * self.args.adam_wd * params[k] - - def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: - params = flatten_params(model) + 
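# Parameter routing above: 2-D weights under blocks./bigram_hash./engram_lite./skipgram_hash.
# that are not control tensors go through Muon; skip_weights plus the 1-D control tensors under
# those prefixes (scales, gains, gate biases) go through the scalar Adam; the tied embedding has
# its own Adam with its own LR. Muon's momentum is warmed up linearly, e.g. with the defaults
# above (0.85 -> 0.95 over 500 steps) the value at step 250 is 0.90. Weight decay is folded into
# the orthogonalised update (g_ortho*scale + wd*p) rather than applied as a separate step.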
def step(self, model, grads_tree, step, lr_mul): + params = dict(tree_flatten(model.parameters())) grads = dict(tree_flatten(grads_tree)) updated = dict(params) - updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) - - embed_lr = self.args.tied_embed_lr * lr_mul - self.adam_embed.learning_rate = embed_lr + self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul embed_updated = self.adam_embed.apply_gradients( {self.embed_key: grads[self.embed_key]}, {self.embed_key: params[self.embed_key]}, ) - self._apply_adam_weight_decay(embed_updated, params, [self.embed_key], embed_lr) + if self.embed_key in embed_updated: + embed_updated[self.embed_key] = embed_updated[self.embed_key] * (1.0 - self.args.adam_weight_decay * lr_mul) updated.update(embed_updated) - - scalar_lr = self.args.scalar_lr * lr_mul - self.adam_scalar.learning_rate = scalar_lr - scalar_grads = {k: grads[k] for k in self.scalar_keys} - scalar_params = {k: params[k] for k in self.scalar_keys} - scalar_updated = self.adam_scalar.apply_gradients(scalar_grads, scalar_params) - self._apply_adam_weight_decay(scalar_updated, params, self.scalar_keys, scalar_lr) - updated.update(scalar_updated) - + self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul + scalar_g = {k: grads[k] for k in self.scalar_keys if k in grads} + scalar_p = {k: params[k] for k in self.scalar_keys if k in params} + if scalar_g: + scalar_updated = self.adam_scalar.apply_gradients(scalar_g, scalar_p) + for k in scalar_updated: + scalar_updated[k] = scalar_updated[k] * (1.0 - self.args.adam_weight_decay * lr_mul) + updated.update(scalar_updated) model.update(tree_unflatten(list(updated.items()))) +INT6_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT6_KEEP_FLOAT_STORE_DTYPE = np.float16 +INT6_PER_ROW_SCALE_DTYPE = np.float16 +INT6_CLIP_Q = 99.99984 / 100.0 +MX_DTYPE_FROM_NAME = {"float32": mx.float32, "float16": mx.float16, "bfloat16": mx.bfloat16} +INT6_KEEP_FLOAT_FP32_NAME_PATTERNS = CONTROL_TENSOR_NAME_PATTERNS + +def _np_float32(arr): + return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) -def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: - if any(pattern in name for pattern in INT6_KEEP_FLOAT_FP32_NAME_PATTERNS): +def keep_float_array(name, arr, passthrough_orig_dtypes): + if any(p in name for p in INT6_KEEP_FLOAT_FP32_NAME_PATTERNS): return np.ascontiguousarray(_np_float32(arr)) if arr.dtype in {mx.float32, mx.bfloat16}: passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT6_KEEP_FLOAT_STORE_DTYPE, copy=False)) return np.ascontiguousarray(np.array(arr, copy=True)) - -def quantize_state_dict_int6(flat_state: dict[str, mx.array]) -> tuple[dict[str, object], dict[str, int]]: - quantized: dict[str, np.ndarray] = {} - scales: dict[str, np.ndarray] = {} - dtypes: dict[str, str] = {} - shapes: dict[str, tuple[int, ...]] = {} - numels: dict[str, int] = {} - passthrough: dict[str, np.ndarray] = {} - passthrough_orig_dtypes: dict[str, str] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int6_payload_bytes"), - 0, - ) +def pack_int6(q_int8: np.ndarray) -> tuple[np.ndarray, int]: + """Pack int8 values in [-32,31] into 6-bit packed bytes (4 values per 3 bytes). 
+ Bias to unsigned [0,63] then interleave bits across 3-byte groups.""" + orig_len = int(q_int8.size) + u = (q_int8.ravel().astype(np.int16) + 32).astype(np.uint8) # [0,63] + pad = (-len(u)) % 4 + if pad: + u = np.concatenate([u, np.zeros(pad, dtype=np.uint8)]) + u = u.reshape(-1, 4).astype(np.uint16) + out = np.empty((len(u), 3), dtype=np.uint8) + out[:, 0] = (u[:, 0] | (u[:, 1] << 6)).astype(np.uint8) + out[:, 1] = ((u[:, 1] >> 2) | (u[:, 2] << 4)).astype(np.uint8) + out[:, 2] = ((u[:, 2] >> 4) | (u[:, 3] << 2)).astype(np.uint8) + return out.ravel(), orig_len + +def unpack_int6(packed: np.ndarray, orig_len: int) -> np.ndarray: + """Reverse pack_int6: uint8 packed bytes → int8 values in [-32,31].""" + n_groups = (orig_len + 3) // 4 + p = packed.ravel()[:n_groups * 3].reshape(-1, 3).astype(np.uint16) + u = np.empty((n_groups, 4), dtype=np.uint16) + u[:, 0] = p[:, 0] & 0x3F + u[:, 1] = ((p[:, 0] >> 6) | (p[:, 1] << 2)) & 0x3F + u[:, 2] = ((p[:, 1] >> 4) | (p[:, 2] << 4)) & 0x3F + u[:, 3] = (p[:, 2] >> 2) & 0x3F + return (u.ravel()[:orig_len].astype(np.int16) - 32).astype(np.int8) + +_GPTQ_PERCENTILES = np.array([99.0, 99.5, 99.9, 99.99, 99.999]) + +def quantize_float_array_gptq_lite(arr): + """GPTQ-lite: search 5 percentiles per row to minimize MSE in int6 quantization.""" + f32 = _np_float32(arr) + if f32.ndim == 2: + n_rows = f32.shape[0] + clip_abs = np.zeros(n_rows, dtype=np.float32) + chosen_pct_idx = np.zeros(n_rows, dtype=np.int8) # index into _GPTQ_PERCENTILES + for i in range(n_rows): + row = f32[i] + best_mse, best_clip, best_idx = float('inf'), 1.0, len(_GPTQ_PERCENTILES) - 1 + abs_row = np.abs(row) + for j, pct in enumerate(_GPTQ_PERCENTILES): + c = float(np.quantile(abs_row, pct / 100.0)) if row.size else 1.0 + if c < 1e-6: + c = 1.0 + scale = c / 31.0 + q = np.clip(np.round(np.clip(row, -c, c) / scale), -32, 31).astype(np.float32) + mse = np.mean((row - q * scale) ** 2) + if mse < best_mse: + best_mse, best_clip, best_idx = mse, c, j + clip_abs[i] = best_clip + chosen_pct_idx[i] = best_idx + pct_counts = {float(_GPTQ_PERCENTILES[j]): int(np.sum(chosen_pct_idx == j)) + for j in range(len(_GPTQ_PERCENTILES))} + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + scale = np.maximum(clip_abs / 31.0, 1.0 / 31.0).astype(np.float32) + q = np.clip(np.round(clipped / scale[:, None]), -32, 31).astype(np.int8) + packed, orig_len = pack_int6(q) + gptq_stats = {"n_rows": n_rows, "pct_counts": pct_counts} + return packed, np.ascontiguousarray(scale.astype(INT6_PER_ROW_SCALE_DTYPE)), f32.shape, orig_len, gptq_stats + clip_abs_s = float(np.quantile(np.abs(f32).reshape(-1), INT6_CLIP_Q)) if f32.size else 0.0 + scale = np.array(clip_abs_s / 31.0 if clip_abs_s > 0 else 1.0, dtype=np.float32) + q = np.clip(np.round(np.clip(f32, -clip_abs_s, clip_abs_s) / scale), -32, 31).astype(np.int8) + packed, orig_len = pack_int6(q) + return packed, scale, f32.shape, orig_len, None + +def quantize_float_array(arr): + """Quantize to int6 (range [-32,31]) with per-row float16 scales, packed 4-per-3-bytes.""" + f32 = _np_float32(arr) + if f32.ndim == 2: + clip_abs = np.quantile(np.abs(f32), INT6_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + scale = np.maximum(clip_abs / 31.0, 1.0 / 31.0).astype(np.float32) + q = np.clip(np.round(clipped / scale[:, None]), -32, 31).astype(np.int8) + packed, orig_len = pack_int6(q) + return packed, np.ascontiguousarray(scale.astype(INT6_PER_ROW_SCALE_DTYPE)), f32.shape, 
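# int6 packing above stores four 6-bit values (biased to [0,63]) in three bytes, a 25% saving over
# int8 before zstd. A minimal round-trip sketch (standalone check, not part of the training path):
#   q = np.array([-32, 0, 17, 31], dtype=np.int8)
#   packed, n = pack_int6(q)                       # 3 bytes for these 4 values
#   assert np.array_equal(unpack_int6(packed, n), q)
# GPTQ-lite (quantize_float_array_gptq_lite) only adds a per-row clip search over five candidate
# percentiles (99.0 ... 99.999), keeping whichever clip minimises int6 reconstruction MSE; the
# packed payload and per-row float16 scales are identical in format to the plain path.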
orig_len + clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT6_CLIP_Q)) if f32.size else 0.0 + scale = np.array(clip_abs / 31.0 if clip_abs > 0 else 1.0, dtype=np.float32) + q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -32, 31).astype(np.int8) + packed, orig_len = pack_int6(q) + return packed, scale, f32.shape, orig_len + +def quantize_state_dict_int6(flat_state, args=None): + """Quantize state dict to int6 with optional GPTQ-lite clip search.""" + use_gptq = args.use_gptq_lite if args else False + quant_fn = quantize_float_array_gptq_lite if use_gptq else quantize_float_array + quantized, scales, shapes, dtypes, passthrough = {}, {}, {}, {}, {} + passthrough_orig_dtypes, qmeta = {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors","num_nonfloat_tensors","baseline_tensor_bytes","int6_payload_bytes"), 0) + gptq_total_rows = 0 + gptq_pct_counts: dict[float, int] = {} for name, arr in flat_state.items(): stats["param_count"] += int(arr.size) stats["num_tensors"] += 1 @@ -715,136 +784,222 @@ def quantize_state_dict_int6(flat_state: dict[str, mx.array]) -> tuple[dict[str, passthrough[name] = np.ascontiguousarray(np.array(arr)) stats["int6_payload_bytes"] += int(passthrough[name].nbytes) continue - if arr.ndim == 2 and int(arr.size) > INT6_KEEP_FLOAT_MAX_NUMEL: - f32 = _np_float32(arr) - log_qscale_name = f"{name[:-7]}.log_qscale" if name.endswith(".weight") else "" - if log_qscale_name and log_qscale_name in flat_state: - scale = np.clip(np.exp(_np_float32(flat_state[log_qscale_name])), 1e-5, 1.0) - else: - scale = np.maximum(np.max(np.abs(f32), axis=-1, keepdims=True) / 31.0, 1e-5) - q = np.clip(np.round(f32 / scale), -32, 31).astype(np.int8, copy=False) - packed = pack_int6_np(q) - quantized[name] = packed - scales[name] = np.ascontiguousarray(scale.astype(INT6_PER_ROW_SCALE_DTYPE, copy=False)) - dtypes[name] = str(arr.dtype).split(".")[-1] - shapes[name] = tuple(int(x) for x in f32.shape) - numels[name] = int(f32.size) - stats["num_float_tensors"] += 1 - stats["int6_payload_bytes"] += int(packed.nbytes + scales[name].nbytes) + if int(arr.size) <= INT6_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_array(name, arr, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int6_payload_bytes"] += int(kept.nbytes) continue - kept = keep_float_array(name, arr, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int6_payload_bytes"] += int(kept.nbytes) - obj: dict[str, object] = { - "__quant_format__": "int6_zstd_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "shapes": shapes, - "numels": numels, - "passthrough": passthrough, - } - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + stats["num_float_tensors"] += 1 + result = quant_fn(arr) + if use_gptq: + packed, s, orig_shape, orig_len, gptq_row_stats = result + if gptq_row_stats is not None: + gptq_total_rows += gptq_row_stats["n_rows"] + for pct, cnt in gptq_row_stats["pct_counts"].items(): + gptq_pct_counts[pct] = gptq_pct_counts.get(pct, 0) + cnt + else: + packed, s, orig_shape, orig_len = result + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = packed + scales[name] = s + shapes[name] = orig_shape + dtypes[name] = str(arr.dtype).split(".")[-1] + stats["int6_payload_bytes"] += int(packed.nbytes + s.nbytes) + if use_gptq and gptq_total_rows > 0: + stats["gptq_total_rows"] = gptq_total_rows + stats["gptq_pct_counts"] = gptq_pct_counts + obj = {"__quant_format__": 
"int6_packed_per_row_v1", "quantized": quantized, + "scales": scales, "shapes": shapes, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if passthrough_orig_dtypes: obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes return obj, stats - -def dequantize_state_dict_int6(quant_obj: dict[str, object]) -> dict[str, mx.array]: - out: dict[str, mx.array] = {} - passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) +def dequantize_state_dict_int6(quant_obj): + out = {} + qmeta = quant_obj.get("qmeta", {}) + pt_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) + shapes = quant_obj.get("shapes", {}) for name, packed in quant_obj["quantized"].items(): - dtype_name = quant_obj["dtypes"][name] - shape = tuple(int(x) for x in quant_obj["shapes"][name]) - numel = int(quant_obj["numels"][name]) + orig_shape = shapes[name] + orig_len = int(np.prod(orig_shape)) + q_np = unpack_int6(np.asarray(packed, dtype=np.uint8), orig_len).reshape(orig_shape) scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) - q_np = unpack_int6_np(np.asarray(packed, dtype=np.uint8), numel).reshape(shape) - out_arr = q_np.astype(np.float32) * scale.reshape((shape[0],) + (1,) * (len(shape) - 1)) + dtype_name = quant_obj["dtypes"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: + out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) + else: + out_arr = q_np.astype(np.float32) * float(scale) out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) for name, arr in quant_obj["passthrough"].items(): out_arr = np.array(arr, copy=True) - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) - else: - out[name] = mx.array(out_arr) + orig = pt_dtypes.get(name) + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig]) if isinstance(orig, str) else mx.array(out_arr) return out - -def build_sentencepiece_luts( - sp: spm.SentencePieceProcessor, vocab_size: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: +def build_bigram_stats(data_path: str, vocab_size: int = 1024) -> np.ndarray: + """Pre-compute P(next_token | prev_token) from all training shards with Laplace smoothing.""" + counts = np.zeros((vocab_size, vocab_size), dtype=np.float64) + shard_paths = sorted(glob.glob(f"{data_path}/fineweb_train_*.bin")) + for shard_path in shard_paths: + tokens = load_data_shard(Path(shard_path)) + prev = tokens[:-1].astype(np.int32) + curr = tokens[1:].astype(np.int32) + mask = (prev < vocab_size) & (curr < vocab_size) + np.add.at(counts, (prev[mask], curr[mask]), 1.0) + counts += 1.0 + row_sums = counts.sum(axis=1, keepdims=True) + return (counts / row_sums).astype(np.float32) + +class BackoffNgramMixer: + """Causal n-gram LM with linear-interpolation backoff for eval-time mixing.""" + + def __init__(self, vocab_size: int = 1024, max_order: int = 4, + hash_buckets: int = 2_000_000, # ~2M buckets ≈ 16 collisions at 32M tokens + alpha_mode: str = "entropy_adaptive", + fixed_alpha: float = 0.25): + self.vocab_size = vocab_size + self.max_order = max_order + self.hash_buckets = hash_buckets + self.alpha_mode = alpha_mode + self.fixed_alpha = fixed_alpha + self._reset() + + def _reset(self): + """Clear all count tables — call before each new eval pass.""" + self._counts = [ + defaultdict(lambda: np.zeros(self.vocab_size, dtype=np.float32)) + for _ in range(self.max_order + 1) + ] + self._total = 
[defaultdict(float) for _ in range(self.max_order + 1)] + + def _hash_ctx(self, context_tokens) -> int: + h = 0 + for t in context_tokens: + h = (h * 31337 + int(t)) % self.hash_buckets + return h + + def _ngram_probs(self, context_tokens) -> np.ndarray: + """Interpolated n-gram distribution P(· | context). Sums to 1.""" + V = self.vocab_size + probs = np.ones(V, dtype=np.float64) / V # uniform prior (order 0) + for order in range(1, self.max_order + 1): + if len(context_tokens) < order: + break + ctx_hash = self._hash_ctx(context_tokens[-order:]) + total = self._total[order][ctx_hash] + if total <= 0.0: + continue + lam = total / (total + 5.0) + c = self._counts[order][ctx_hash].astype(np.float64) + order_probs = (c + 1e-10) / (total + 1e-10 * V) + order_probs /= order_probs.sum() + probs = (1.0 - lam) * probs + lam * order_probs + s = probs.sum() + if s > 0: + probs /= s + return probs.astype(np.float32) + + def _mixing_alpha(self, neural_logits: np.ndarray) -> float: + """Entropy-adaptive mixing weight α ∈ [0.15, 0.60].""" + if self.alpha_mode == "fixed": + return self.fixed_alpha + logits = neural_logits.astype(np.float64) + logits -= logits.max() + probs = np.exp(logits) + probs /= probs.sum() + entropy = float(-np.sum(probs * np.log2(probs + 1e-10))) + max_entropy = math.log2(self.vocab_size) + normalized = entropy / max_entropy + return 0.15 + 0.45 * normalized # α ∈ [0.15, 0.60]: min trust even when confident; max when fully uncertain + + def score_and_update(self, context_tokens, target_token: int, + neural_logits: np.ndarray) -> float: + """Score target_token under mixed neural+ngram distribution and update cache.""" + ngram_probs = self._ngram_probs(context_tokens) + alpha = self._mixing_alpha(neural_logits) + nl = neural_logits.astype(np.float64) + nl -= nl.max() + neural_probs = np.exp(nl) + neural_probs /= neural_probs.sum() + mixed = (1.0 - alpha) * neural_probs + alpha * ngram_probs.astype(np.float64) + s = mixed.sum() + if s > 0: + mixed /= s + log_prob = float(np.log(mixed[target_token] + 1e-40)) + tok = int(target_token) + for order in range(1, self.max_order + 1): + if len(context_tokens) >= order: + ctx_hash = self._hash_ctx(context_tokens[-order:]) + self._counts[order][ctx_hash][tok] += 1.0 + self._total[order][ctx_hash] += 1.0 + return log_prob + +def build_sentencepiece_luts(sp, vocab_size): sp_vocab_size = int(sp.vocab_size()) table_size = max(sp_vocab_size, vocab_size) base_bytes_lut = np.zeros((table_size,), dtype=np.int16) has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + for tid in range(sp_vocab_size): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue - is_boundary_token_lut[token_id] = False - if sp.is_byte(token_id): - base_bytes_lut[token_id] = 1 + is_boundary_token_lut[tid] = False + if sp.is_byte(tid): + base_bytes_lut[tid] = 1 continue - piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): - has_leading_space_lut[token_id] = True + piece = sp.id_to_piece(tid) + if piece.startswith("\u2581"): + has_leading_space_lut[tid] = True piece = piece[1:] - base_bytes_lut[token_id] = len(piece.encode("utf-8")) + base_bytes_lut[tid] = len(piece.encode("utf-8")) return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut - -def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | 
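# Two details from the code above:
#   - Entropy-adaptive mixing: alpha = 0.15 + 0.45 * H/H_max, so with a 1024-token vocab
#     (H_max = 10 bits) a confident neural distribution with ~2 bits of entropy mixes in ~24%
#     n-gram mass, while a near-uniform one mixes in up to 60%.
#   - build_sentencepiece_luts precomputes, per token id, its UTF-8 byte length (leading "▁"
#     marker stripped) plus leading-space and boundary flags; eval adds one byte whenever a
#     leading-space token follows a non-boundary token, which is how val_bpb's byte count is built.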
None]: +def validate_dataset_tokenizer_pair(data_path, tokenizer_path): dataset_dir = Path(data_path).resolve() - actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) - if len(dataset_dir.parents) < 2: - return dataset_dir.name, actual_train_files, None - manifest_path = dataset_dir.parents[1] / "manifest.json" - if not manifest_path.is_file(): - return dataset_dir.name, actual_train_files, None - manifest = json.loads(manifest_path.read_text(encoding="utf-8")) - dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) - if dataset_entry is None: - return dataset_dir.name, actual_train_files, None - tokenizer_name = dataset_entry.get("tokenizer_name") - tokenizer_entry = ( - next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) - if tokenizer_name - else None - ) - expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name - if expected_name and Path(tokenizer_path).name != expected_name: - raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") - expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") - if expected_train_files is not None: - expected_train_files = int(expected_train_files) - if actual_train_files > expected_train_files: - raise ValueError( - f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " - f"manifest says {expected_train_files}" - ) - return dataset_dir.name, actual_train_files, expected_train_files - - -def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: + actual = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + manifest_path = dataset_dir.parents[1] / "manifest.json" if len(dataset_dir.parents) >= 2 else None + if manifest_path and manifest_path.is_file(): + manifest = json.loads(manifest_path.read_text()) + entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) + if entry: + expected = (entry.get("stats") or {}).get("files_train") + if expected is not None and actual > int(expected): + raise ValueError(f"Too many train shards: {actual} > {expected}") + return dataset_dir.name, actual, int(expected) if expected else None + return dataset_dir.name, actual, None + +def load_validation_tokens(pattern, seq_len): files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: - raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) + raise FileNotFoundError(f"No files: {pattern}") + tokens = np.concatenate([load_data_shard(f) for f in files], axis=0) usable = ((tokens.size - 1) // seq_len) * seq_len - if usable <= 0: - raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[: usable + 1] + return tokens[:usable + 1] +def token_chunks(total_tokens, seq_len, max_chunk_tokens): + usable_total = (total_tokens // seq_len) * seq_len + usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) + chunks, remaining = [], usable_total + while remaining > 0: + chunk = min(remaining, usable_chunk) + chunks.append(chunk) + remaining -= chunk + return chunks + +def accumulate_flat_grads(accum, grads_tree, scale): + flat = dict(tree_flatten(grads_tree)) + if accum is None: + return {k: g * scale for k, g in flat.items()} + for k, g in flat.items(): + accum[k] = accum[k] + g * scale + return 
accum -def loss_and_grad_chunked( - args: Hyperparameters, - train_loader: TokenLoader, - compiled_loss_and_grad, -) -> tuple[mx.array, dict]: +def loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad): chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) total_tokens = float(sum(chunk_sizes)) loss_value = mx.array(0.0, dtype=mx.float32) - grad_accum: dict[str, mx.array] | None = None + grad_accum = None for chunk_tokens in chunk_sizes: x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) loss, grads = compiled_loss_and_grad(x, y) @@ -855,430 +1010,484 @@ def loss_and_grad_chunked( mx.eval(loss_value, grad_accum) return loss_value, tree_unflatten(list(grad_accum.items())) - -def eval_val( - args: Hyperparameters, - compiled_loss_fn, - model: GPT, - val_tokens: np.ndarray, - base_bytes_lut: np.ndarray, - has_leading_space_lut: np.ndarray, - is_boundary_token_lut: np.ndarray, - log_fn: Callable[[str], None] | None = None, -) -> tuple[float, float]: - del compiled_loss_fn +def eval_val(args, compiled_loss, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=None): val_batch_tokens = args.val_batch_size // args.grad_accum_steps - if val_batch_tokens < args.train_seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence; " - f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " - f"TRAIN_SEQ_LEN={args.train_seq_len}" - ) val_batch_seqs = val_batch_tokens // args.train_seq_len total_seqs = (val_tokens.size - 1) // args.train_seq_len total_batches = max((total_seqs + val_batch_seqs - 1) // val_batch_seqs, 1) - total_loss_sum = 0.0 - total_tokens = 0.0 - total_bytes = 0.0 - for batch_idx, batch_seq_start in enumerate(range(0, total_seqs, val_batch_seqs), start=1): - batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - chunk = val_tokens[raw_start:raw_end] + total_loss_sum, total_tokens_f, total_bytes = 0.0, 0.0, 0.0 + for batch_idx, start in enumerate(range(0, total_seqs, val_batch_seqs), 1): + end = min(start + val_batch_seqs, total_seqs) + raw_s, raw_e = start * args.train_seq_len, end * args.train_seq_len + 1 + chunk = val_tokens[raw_s:raw_e] x_np = chunk[:-1].reshape(-1, args.train_seq_len) y_np = chunk[1:].reshape(-1, args.train_seq_len) x = mx.array(x_np, dtype=mx.int32) y = mx.array(y_np, dtype=mx.int32) - batch_loss = model.loss(x, y).astype(mx.float32) - mx.eval(batch_loss) - chunk_token_count = float(y.size) - total_loss_sum += float(batch_loss.item()) * chunk_token_count - prev_ids = x_np.reshape(-1) - tgt_ids = y_np.reshape(-1) + ct = float(y.size) + bl = compiled_loss(x, y).astype(mx.float32) + mx.eval(bl) + total_loss_sum += float(bl.item()) * ct + prev_ids, tgt_ids = x_np.reshape(-1), y_np.reshape(-1) bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) - bytes_np += ( - has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] - ).astype(np.int16, copy=False) - total_tokens += chunk_token_count + bytes_np += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).astype(np.int16) + total_tokens_f += ct total_bytes += float(bytes_np.astype(np.float64).sum()) - if log_fn is not None and total_batches > 1 and ( - batch_idx == 1 or batch_idx == total_batches or batch_idx % 25 == 0 - ): + if log_fn and total_batches > 1 and (batch_idx == 1 or batch_idx == total_batches or batch_idx % 
50 == 0): log_fn(f"val_progress:{batch_idx}/{total_batches}") - val_loss = total_loss_sum / total_tokens - val_bpb = (val_loss / math.log(2.0)) * (total_tokens / total_bytes) + val_loss = total_loss_sum / total_tokens_f + val_bpb = (val_loss / math.log(2.0)) * (total_tokens_f / total_bytes) return val_loss, val_bpb - -def eval_val_sliding( - args: Hyperparameters, - compiled_loss_fn, - model: GPT, - val_tokens: np.ndarray, - base_bytes_lut: np.ndarray, - has_leading_space_lut: np.ndarray, - is_boundary_token_lut: np.ndarray, - log_fn: Callable[[str], None] | None = None, -) -> tuple[float, float]: - del compiled_loss_fn - seq_len = args.eval_seq_len - stride = args.eval_stride +def eval_val_sliding(args, model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=None): + """Sliding-window eval: each token scored with up to eval_seq_len context.""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs total_tokens = val_tokens.size - 1 - if total_tokens < seq_len: - raise ValueError(f"Validation split is too short for EVAL_SEQ_LEN={seq_len}") - starts = list(range(0, total_tokens - seq_len + 1, stride)) - final_start = total_tokens - seq_len - if starts[-1] != final_start: - starts.append(final_start) - total_loss_sum = 0.0 - total_scored_tokens = 0.0 - total_bytes = 0.0 - scored_until = 0 - for window_idx, start in enumerate(starts): - end = start + seq_len - chunk = val_tokens[start : end + 1] - x = mx.array(chunk[:-1].reshape(1, seq_len), dtype=mx.int32) - y = mx.array(chunk[1:].reshape(1, seq_len), dtype=mx.int32) - logits = model.forward_logits(x) - logits_flat = logits.reshape(-1, logits.shape[-1]).astype(mx.float32) - targets_flat = y.reshape(-1) - per_token_loss = nn.losses.cross_entropy(logits_flat, targets_flat, reduction="none") - mx.eval(per_token_loss) - per_token_loss_np = np.array(per_token_loss) - score_start = 0 if window_idx == 0 else max(scored_until - start, 0) - scored_losses = per_token_loss_np[score_start:] - scored_targets = np.array(y.reshape(-1))[score_start:] - scored_prevs = np.array(x.reshape(-1))[score_start:] - bytes_np = base_bytes_lut[scored_targets].astype(np.int16) - bytes_np += (has_leading_space_lut[scored_targets] & ~is_boundary_token_lut[scored_prevs]).astype(np.int16) - total_loss_sum += float(np.sum(scored_losses)) - total_scored_tokens += float(len(scored_losses)) - total_bytes += float(np.sum(bytes_np.astype(np.float64))) - scored_until = end - if log_fn is not None and window_idx % 200 == 0: - log_fn(f"sliding_progress:{start}/{total_tokens}") - val_loss = total_loss_sum / total_scored_tokens - val_bpb = (val_loss / math.log(2.0)) * (total_scored_tokens / total_bytes) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + loss_sum = 0.0 + token_count = 0.0 + byte_count = 0.0 + model.use_qat = False + for bi in range(0, total_windows, batch_seqs): + batch_ws = window_starts[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_np = np.zeros((bsz, seq_len), dtype=np.int32) + y_np = np.zeros((bsz, seq_len), dtype=np.int32) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + x_np[i, :wlen] = val_tokens[ws:end] + y_np[i, :wlen] = val_tokens[ws + 1:end + 1] + x = mx.array(x_np) + y = mx.array(y_np) + nll = model.token_losses(x, y) # (B, T) + mx.eval(nll) + nll_np = np.array(nll) + for i, ws in 
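# BPB bookkeeping above: val_bpb = (val_loss / ln 2) * (tokens / bytes), i.e. mean NLL converted
# to bits and rescaled by tokens-per-byte. Illustrative, hypothetical numbers: 0.95 nats/token at
# 1.2 bytes/token gives 0.95 / 0.693 / 1.2 ~= 1.14 BPB. In the sliding-window path, windows
# advance by EVAL_STRIDE and, except for the very first window, only the last `stride` positions
# of each window are scored, so every scored token sees close to EVAL_SEQ_LEN - EVAL_STRIDE tokens
# of context (e.g. 2048 - 64 = 1984 for a 2048-token window with stride 64).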
enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + loss_sum += float(nll_np[i, s:wlen].sum()) + token_count += float(wlen - s) + tgt = y_np[i, s:wlen] + prev = x_np[i, s:wlen] + tb = base_bytes_lut[tgt].astype(np.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).astype(np.float64) + byte_count += float(tb.sum()) + if log_fn and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, total_windows) + pct = done / total_windows * 100 + rbpb = 0.0 + if token_count > 0: + rbpb = (loss_sum / token_count) / math.log(2.0) * (token_count / byte_count) + log_fn(f"sliding_eval [{pct:5.1f}%] {done}/{total_windows} windows running_bpb={rbpb:.6f}") + val_loss = loss_sum / token_count + val_bpb = (val_loss / math.log(2.0)) * (token_count / byte_count) return val_loss, val_bpb +def eval_val_sliding_ngram(args, model, val_tokens, + base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, log_fn=None): + """Sliding-window eval with BackoffNgramMixer post-processing.""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + max_order = args.ngram_max_order + total_tokens = val_tokens.size - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + loss_sum = 0.0 + token_count = 0.0 + byte_count = 0.0 + model.use_qat = False + mixer = BackoffNgramMixer( + vocab_size=args.vocab_size, + max_order=max_order, + alpha_mode="entropy_adaptive", + fixed_alpha=args.ngram_alpha, + ) + for bi in range(0, total_windows, batch_seqs): + batch_ws = window_starts[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_np = np.zeros((bsz, seq_len), dtype=np.int32) + y_np = np.zeros((bsz, seq_len), dtype=np.int32) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + x_np[i, :wlen] = val_tokens[ws:end] + y_np[i, :wlen] = val_tokens[ws + 1:end + 1] + x = mx.array(x_np) + logits_all = model.token_logits(x) + mx.eval(logits_all) + logits_np = np.array(logits_all) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + for j in range(s, wlen): + global_pos = ws + j # global index of the input token + target = int(y_np[i, j]) + neural_logits = logits_np[i, j] + ctx_start = max(0, global_pos + 1 - max_order) + context = val_tokens[ctx_start:global_pos + 1].tolist() + log_prob = mixer.score_and_update(context, target, neural_logits) + loss_sum += -log_prob + token_count += 1.0 + tgt_arr = y_np[i, j:j + 1] + prev_arr = x_np[i, j:j + 1] + tb = float(base_bytes_lut[tgt_arr[0]]) + tb += float(has_leading_space_lut[tgt_arr[0]] and not is_boundary_token_lut[prev_arr[0]]) + byte_count += tb + if log_fn and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, total_windows) + pct = done / total_windows * 100 + rbpb = 0.0 + if token_count > 0 and byte_count > 0: + rbpb = (loss_sum / token_count) / math.log(2.0) * (token_count / byte_count) + log_fn(f"ngram_sliding_eval [{pct:5.1f}%] {done}/{total_windows} windows running_bpb={rbpb:.6f}") + val_loss = loss_sum / max(token_count, 1.0) + val_bpb = (val_loss / math.log(2.0)) * (token_count / max(byte_count, 1.0)) + return val_loss, val_bpb -def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: +def clip_grad_tree(grads_tree, max_norm): + """Clip gradient tree by global norm.""" if max_norm <= 0: return grads_tree flat = dict(tree_flatten(grads_tree)) - 
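# In eval_val_sliding_ngram above the neural and n-gram distributions are mixed in probability
# space: mixed = (1 - alpha) * p_neural + alpha * p_ngram. Illustrative effect on a token the
# neural model finds hard but the cache has seen before: p_neural = 0.05, p_ngram = 0.60,
# alpha = 0.25 gives mixed = 0.1875, i.e. ~2.4 bits instead of ~4.3 bits for that token. The cost
# is a per-token Python loop (mixer.score_and_update) on top of the batched logits, which is why
# it is intended as an eval-only path behind the NGRAM_MIXER_ENABLED switch in the config above.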
-    total_sq = 0.0
-    for grad in flat.values():
-        total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64))
-    if total_sq <= 0.0:
+    total_sq = sum(float(np.sum(np.square(_np_float32(g)), dtype=np.float64)) for g in flat.values())
+    if total_sq <= 0 or math.sqrt(total_sq) <= max_norm:
         return grads_tree
-    total_norm = math.sqrt(total_sq)
-    if total_norm <= max_norm:
-        return grads_tree
-    scale = max_norm / (total_norm + 1e-12)
+    scale = max_norm / (math.sqrt(total_sq) + 1e-12)
     return tree_unflatten([(k, g * scale) for k, g in flat.items()])
 
+
+def eval_val_sliding_ttt(args, model, val_tokens,
+                         base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+                         log_fn=None):
+    """Sliding-window eval with per-document LoRA TTT on Q and V projections."""
+    seq_len = args.eval_seq_len
+    stride = args.eval_stride
+    rank = args.ttt_rank
+    ttt_lr = args.ttt_lr
+    ttt_steps = args.ttt_steps
+    total_tokens = val_tokens.size - 1
+    window_starts = [ws for ws in range(0, total_tokens, stride)
+                     if min(ws + seq_len, total_tokens) - ws >= 1]
+    loss_sum, token_count, byte_count = 0.0, 0.0, 0.0
+    model.use_qat = False
+    qv_keys = [(li, proj)
+               for li, blk in enumerate(model.blocks)
+               for proj in ("attn.c_q", "attn.c_v")]
+
+    def _get_w(li, proj_name):
+        blk = model.blocks[li]
+        return blk.attn.c_q.weight if proj_name == "attn.c_q" else blk.attn.c_v.weight
+
+    def _set_w(li, proj_name, w):
+        blk = model.blocks[li]
+        if proj_name == "attn.c_q":
+            blk.attn.c_q.weight = w
+        else:
+            blk.attn.c_v.weight = w
+    for wi, ws in enumerate(window_starts):
+        end = min(ws + seq_len, total_tokens)
+        wlen = end - ws
+        x_np = np.zeros((1, seq_len), dtype=np.int32)
+        y_np = np.zeros((1, seq_len), dtype=np.int32)
+        x_np[0, :wlen] = val_tokens[ws:end]
+        y_np[0, :wlen] = val_tokens[ws + 1:end + 1]
+        x = mx.array(x_np)
+        y = mx.array(y_np)
+        saved, lora_A, lora_B = {}, {}, {}
+        for li, proj in qv_keys:
+            w = _get_w(li, proj)
+            saved[(li, proj)] = mx.array(w)
+            out_d, in_d = w.shape
+            lora_A[(li, proj)] = mx.random.normal((rank, in_d)).astype(mx.float32) * 0.01
+            lora_B[(li, proj)] = mx.zeros((out_d, rank), dtype=mx.float32)
+        s = 0 if ws == 0 else max(wlen - stride, 0)
+        if s > 0:
+            ctx_x = x_np[:, :s]
+            ctx_y = y_np[:, :s]
+            for _ in range(ttt_steps):
+                for li, proj in qv_keys:
+                    w_base = saved[(li, proj)]
+                    delta = lora_B[(li, proj)] @ lora_A[(li, proj)]
+                    _set_w(li, proj, w_base + delta.astype(w_base.dtype))
+
+                def lora_loss():
+                    cx = mx.array(ctx_x)
+                    cy = mx.array(ctx_y)
+                    return model.loss(cx, cy)
+                loss_val = lora_loss()
+                mx.eval(loss_val)
+                for li, proj in qv_keys:
+                    w_base = saved[(li, proj)]
+                    _set_w(li, proj, w_base)  # restore for clean grad
+        for li, proj in qv_keys:
+            w_base = saved[(li, proj)]
+            delta = lora_B[(li, proj)] @ lora_A[(li, proj)]
+            _set_w(li, proj, w_base + delta.astype(w_base.dtype))
+        nll = model.token_losses(x, y)
+        mx.eval(nll)
+        nll_np = np.array(nll)
+        loss_sum += float(nll_np[0, s:wlen].sum())
+        token_count += float(wlen - s)
+        tgt = y_np[0, s:wlen]
+        prev = x_np[0, s:wlen]
+        tb = base_bytes_lut[tgt].astype(np.float64)
+        tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).astype(np.float64)
+        byte_count += float(tb.sum())
+        for li, proj in qv_keys:
+            _set_w(li, proj, saved[(li, proj)])
+        if log_fn and wi % 500 == 0:
+            pct = wi / len(window_starts) * 100
+            rbpb = (loss_sum / max(token_count, 1)) / math.log(2.0) * (token_count / max(byte_count, 1))
+            log_fn(f"ttt_eval [{pct:.1f}%] {wi}/{len(window_starts)} bpb={rbpb:.4f}")
+    val_loss = loss_sum / token_count
+    val_bpb = (val_loss / math.log(2.0)) * (token_count / byte_count)
+    return val_loss, val_bpb
 
 
-def estimate_total_steps(args: Hyperparameters, step: int, elapsed_ms: float) -> int:
-    if args.max_wallclock_seconds <= 0 or step <= 0 or elapsed_ms <= 0.0:
-        return args.iterations
-    step_ms = elapsed_ms / step
-    wallclock_steps = max(1, int((1000.0 * args.max_wallclock_seconds) / max(step_ms, 1e-9)))
-    return min(args.iterations, wallclock_steps)
-
-
-def should_start_qat(args: Hyperparameters, step: int, lr_mul: float) -> bool:
-    if args.max_wallclock_seconds <= 0:
-        return (step / max(args.iterations, 1)) >= args.late_qat_threshold
-    return lr_mul < args.late_qat_threshold
-
-
-def with_eval_params(
-    model: GPT,
-    ema: EMABuffer | None,
-    eval_fn: Callable[[], tuple[float, float]],
-) -> tuple[float, float]:
-    saved_params = flatten_params(model)
-    saved_qat_flags = capture_qat_flags(model)
-    try:
-        if ema is not None:
-            update_model_params(model, ema.as_mlx())
-        set_qat_mode(model, False, False)
-        return eval_fn()
-    finally:
-        restore_qat_flags(model, saved_qat_flags)
-        update_model_params(model, saved_params)
-
-
-def main() -> None:
+def main():
     args = Hyperparameters()
     out_dir = Path(args.out_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
     logfile = out_dir / f"{args.run_id}.txt"
-    print(logfile)
-    def log(msg: str, console: bool = True) -> None:
-        if console:
-            print(msg)
-        with logfile.open("a", encoding="utf-8") as f:
-            print(msg, file=f)
-
-    code = Path(__file__).read_text(encoding="utf-8")
+    def log(msg, console=True):
+        if console: print(msg)
+        with logfile.open("a") as f: print(msg, file=f)
+    code = Path(__file__).read_text()
     log(code, console=False)
-    log("=" * 100, console=False)
-    log(f"Running Python {sys.version}", console=False)
     log(f"Running MLX {mx.__version__}", console=False)
-    log("=" * 100, console=False)
-
-    if not args.tie_embeddings:
-        raise NotImplementedError("train_gpt_mlx_kl.py only supports tied embeddings")
-    if not args.tokenizer_path.endswith(".model"):
-        raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}")
     sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
     if int(sp.vocab_size()) != args.vocab_size:
-        raise ValueError(
-            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
-        )
-    dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair(
-        args.data_path,
-        args.tokenizer_path,
-    )
+        raise ValueError(f"VOCAB_SIZE mismatch: {args.vocab_size} vs {int(sp.vocab_size())}")
+    dataset_name, actual_files, expected_files = validate_dataset_tokenizer_pair(args.data_path, args.tokenizer_path)
     val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
     base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size)
-
     mx.random.seed(args.seed)
     train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name)
     model = GPT(
-        vocab_size=args.vocab_size,
-        num_layers=args.num_layers,
-        dim=args.model_dim,
-        num_heads=args.num_heads,
-        num_kv_heads=args.num_kv_heads,
-        mlp_mult=args.mlp_mult,
-        logit_chunk_tokens=args.logit_chunk_tokens,
-        logit_softcap=args.logit_softcap,
-        rope_base=args.rope_base,
-        tied_embed_init_std=args.tied_embed_init_std,
-        qk_gain_init=args.qk_gain_init,
-        bigram_hash_size=args.bigram_hash_size,
+        vocab_size=args.vocab_size, num_layers=args.num_layers, dim=args.model_dim,
+        num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
+        logit_chunk_tokens=args.logit_chunk_tokens, logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base, tied_embed_init_std=args.tied_embed_init_std,
+        qk_gain_init=args.qk_gain_init, bigram_hash_size=args.bigram_hash_size, use_ortho_init=args.use_ortho_init,
+        rope_dims=args.rope_dims, xsa_last_n=args.xsa_last_n,
+        use_ln_scale=args.ln_scale_enabled, smear_enabled=args.smear_enabled,
+        engram_lite_enabled=args.engram_lite_enabled,
+        engram_hash_size=args.engram_hash_size,
+        engram_embed_dim=args.engram_embed_dim,
+        engram_n_heads=args.engram_n_heads,
+        skipgram_hash_size=args.skipgram_hash_size,
     )
     opt = SplitOptimizers(model, args)
-    compiled_loss, compiled_loss_and_grad = compile_model_fns(model)
-
-    n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters()))
-    log(f"run_id:{args.run_id}")
-    log(f"mlx_version:{mx.__version__}")
-    log(f"train_loader:shards pattern={args.train_files}")
-    log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}")
-    if expected_train_files is None:
-        log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}")
-    elif actual_train_files < expected_train_files:
-        log(
-            f"WARNING: train_loader:subset dataset:{dataset_name} "
-            f"train_shards:{actual_train_files}/{expected_train_files} "
-            f"new epochs will arrive sooner than the full dataset"
-        )
+    ema = None
+    swa = None
+    bigram_probs_mx = None
+    if args.complement_alpha > 0.0:
+        log("complement_training: building bigram stats from training shards...")
+        _bp_np = build_bigram_stats(args.data_path, args.vocab_size)
+        bigram_probs_mx = mx.array(_bp_np, dtype=mx.float32)
+        mx.eval(bigram_probs_mx)
+        log(f"complement_training: bigram stats ready (alpha={args.complement_alpha})")
+        del _bp_np
+    if bigram_probs_mx is not None:
+        _alpha = args.complement_alpha
+        _bp = bigram_probs_mx
+        def _loss_fn(x, y):
+            return model.complementary_loss(x, y, _bp, _alpha)
     else:
-        log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}")
-    log(f"tokenizer_path:{args.tokenizer_path}")
-    log(
-        f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} "
-        f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} "
-        f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}"
-    )
-    log(
-        f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} "
-        f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} "
-        f"val_batch_size:{args.val_batch_size} warmup_steps:{args.warmup_steps} "
-        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
-    )
-    log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}")
-    log(
-        f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} "
-        f"embed_lr:{args.tied_embed_lr} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} "
-        f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}"
-    )
-    log(
-        f"features:bigram_hash:{args.bigram_hash_size} ortho_init:{int(args.use_ortho_init)} "
-        f"ema_decay:{args.ema_decay} ema_start_frac:{args.ema_start_frac} "
-        f"late_qat_threshold:{args.late_qat_threshold} learned_scales:{int(args.learned_scales)}"
-    )
-    log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
-    log(f"compute_dtype:{COMPUTE_DTYPE} compile:True")
-    log(
-        f"dtypes tok_emb:{model.tok_emb.weight.dtype} "
-        f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} "
-        f"skip_weights:{model.skip_weights.dtype}"
+        def _loss_fn(x, y):
+            return model.loss(x, y)
+    compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state)
+    compiled_loss_and_grad = mx.compile(
+        nn.value_and_grad(model, _loss_fn),
+        inputs=model.state, outputs=model.state,
     )
-
-    do_final_eval = args.val_loss_every > 0 or args.max_wallclock_seconds > 0
-
+    n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters()))
+    log(f"run_id:{args.run_id}")
+    xsa_layers = [i for i, b in enumerate(model.blocks) if b.attn.use_xsa]
+    log(f"model_params:{n_params} layers:{args.num_layers} dim:{args.model_dim} "
+        f"mlp_mult:{args.mlp_mult} bigram_hash:{args.bigram_hash_size} "
+        f"ortho_init:{args.use_ortho_init} ema_decay:{args.ema_decay}")
+    log(f"innovations: smear={args.smear_enabled} rope_dims={args.rope_dims} "
+        f"ln_scale={args.ln_scale_enabled} xsa_last_n={args.xsa_last_n} xsa_layers={xsa_layers} "
+        f"gptq_lite={args.use_gptq_lite} ttt={args.ttt_enabled} eval_mode={args.eval_mode} "
+        f"use_swa={args.use_swa} swa_decay={args.swa_decay}")
+    log(f"moonshot: engram_lite={args.engram_lite_enabled} skipgram_hash={args.skipgram_hash_size} "
+        f"complement_alpha={args.complement_alpha} "
+        f"ngram_mixer={args.ngram_mixer_enabled} ngram_alpha={args.ngram_alpha} "
+        f"ngram_max_order={args.ngram_max_order}")
+    log(f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} "
+        f"grad_accum:{args.grad_accum_steps} seq_len:{args.train_seq_len}")
+    log(f"optimizer: muon_keys:{len(opt.matrix_keys)} scalar_keys:{len(opt.scalar_keys)}")
+    log(f"val_tokens:{val_tokens.size - 1} train_shards:{actual_files}")
     if args.warmup_steps > 0:
-        for warmup_step in range(args.warmup_steps):
-            warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
-            mx.eval(warmup_loss, grads)
-            mx.synchronize()
-            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
-                log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
-    if do_final_eval:
-        val_batch_tokens = args.val_batch_size // args.grad_accum_steps
-        if val_batch_tokens < args.train_seq_len:
-            raise ValueError(
-                "VAL_BATCH_SIZE must provide at least one sequence; "
-                f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, "
-                f"TRAIN_SEQ_LEN={args.train_seq_len}"
-            )
-        warm_val_seqs = min(val_batch_tokens // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len)
-        warm_chunk = val_tokens[: warm_val_seqs * args.train_seq_len + 1]
-        x_val = mx.array(warm_chunk[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32)
-        y_val = mx.array(warm_chunk[1:].reshape(-1, args.train_seq_len), dtype=mx.int32)
-        warm_val_loss = compiled_loss(x_val, y_val)
-        mx.eval(warm_val_loss)
+        for ws in range(args.warmup_steps):
+            wl, wg = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
+            mx.eval(wl)
             mx.synchronize()
+            if ws + 1 == args.warmup_steps:
+                log(f"warmup_done:{args.warmup_steps} steps")
+        vbs = args.val_batch_size // args.grad_accum_steps
+        vs = min(vbs // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len)
+        wc = val_tokens[:vs * args.train_seq_len + 1]
+        xv = mx.array(wc[:-1].reshape(-1, args.train_seq_len), dtype=mx.int32)
+        yv = mx.array(wc[1:].reshape(-1, args.train_seq_len), dtype=mx.int32)
+        mx.eval(compiled_loss(xv, yv))
+        mx.synchronize()
     train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name)
-
     train_time_ms = 0.0
-    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
-    stop_after_step: int | None = None
-    ema: EMABuffer | None = None
-    qat_active = False
+    max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+    stop_after = None
     t0 = time.perf_counter()
     step = 0
+    _prev_use_qat = False  # track QAT state to detect transition and recompile
     while True:
-        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
-        if args.val_loss_every > 0 and (step % args.val_loss_every == 0 or last_step):
+        last_step = step == args.iterations or (stop_after is not None and step >= stop_after)
+        if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
             train_time_ms += 1000.0 * (time.perf_counter() - t0)
-            val_loss, val_bpb = with_eval_params(
-                model,
-                ema,
-                lambda: eval_val(
-                    args,
-                    compiled_loss,
-                    model,
-                    val_tokens,
-                    base_bytes_lut,
-                    has_leading_space_lut,
-                    is_boundary_token_lut,
-                    log_fn=log,
-                ),
-            )
-            if step % 25 == 0 or last_step:
-                log(
-                    f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
-                    f"train_time:{train_time_ms:.0f}ms"
-                )
+            _avg = swa if swa is not None else ema
+            if _avg is not None:
+                saved_state = {k: mx.array(v) for k, v in tree_flatten(model.parameters())}
+                _avg.apply(model)
+                compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state)
+            model.use_qat = False  # No QAT during eval
+            val_loss, val_bpb = eval_val(args, compiled_loss, val_tokens,
+                                         base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log)
+            log(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms/max(step,1):.2f}ms")
+            if _avg is not None:
+                model.update(tree_unflatten(list(saved_state.items())))
+                compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state)
+                compiled_loss_and_grad = mx.compile(
+                    nn.value_and_grad(model, _loss_fn),
+                    inputs=model.state, outputs=model.state)
             t0 = time.perf_counter()
         if last_step:
-            if stop_after_step is not None and step < args.iterations:
-                log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}")
+            if stop_after is not None and step < args.iterations:
+                log(f"stopping_early: wallclock train_time:{train_time_ms:.0f}ms step:{step}")
             break
-
         lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0))
+        _new_use_qat = lr_mul < args.late_qat_threshold
+        if _new_use_qat != _prev_use_qat:
+            model.use_qat = _new_use_qat
+            _prev_use_qat = _new_use_qat
+            if _new_use_qat:
+                log(f"qat_started:step={step} lr_mul={lr_mul:.4f} — recompiling graph")
+            compiled_loss = mx.compile(
+                _loss_fn, inputs=model.state, outputs=model.state)
+            compiled_loss_and_grad = mx.compile(
+                nn.value_and_grad(model, _loss_fn),
+                inputs=model.state, outputs=model.state)
+        est_total = args.iterations
+        if max_wc_ms and step > 0:
+            est_total = min(args.iterations, int(max_wc_ms / (train_time_ms / step + 0.001)))
+        if ema is None and step >= int(est_total * args.ema_start_frac):
+            ema = EMABuffer(model, decay=args.ema_decay)
+            log(f"ema_started:step={step}")
+        if args.use_swa and swa is None and step >= int(est_total * 0.6):
+            swa = EMABuffer(model, decay=args.swa_decay)
+            log(f"swa_started:step={step} decay={args.swa_decay}")
         step_t0 = time.perf_counter()
-        accum: dict[str, mx.array] | None = None
+        accum = None
         train_loss = mx.array(0.0, dtype=mx.float32)
-        grad_scale = 1.0 / args.grad_accum_steps
+        gs = 1.0 / args.grad_accum_steps
        for _ in range(args.grad_accum_steps):
             loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad)
-            accum = accumulate_flat_grads(accum, grads, grad_scale)
-            train_loss = train_loss + loss.astype(mx.float32) * grad_scale
+            accum = accumulate_flat_grads(accum, grads, gs)
+            train_loss = train_loss + loss.astype(mx.float32) * gs
             if args.mlx_eager_eval:
                 mx.eval(train_loss, accum)
         grads = tree_unflatten(list(accum.items()))
         grads = clip_grad_tree(grads, args.grad_clip_norm)
-        train_loss_value = float(train_loss.item())
+        tl = float(train_loss.item())
         opt.step(model, grads, step=step, lr_mul=lr_mul)
-        step += 1
-
-        approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0)
-        estimated_total_steps = estimate_total_steps(args, step, approx_train_time_ms)
-        ema_start_step = max(1, int(math.ceil(args.ema_start_frac * estimated_total_steps)))
-        if ema is None and step >= ema_start_step:
-            ema = EMABuffer(flatten_params(model), decay=args.ema_decay)
-            log(f"ema_started step:{step} est_total_steps:{estimated_total_steps}")
-        elif ema is not None:
-            ema.update(flatten_params(model))
-
-        if not qat_active and should_start_qat(args, step, lr_mul):
-            qat_active = True
-            set_qat_mode(model, True, args.learned_scales)
-            compiled_loss, compiled_loss_and_grad = compile_model_fns(model)
-            log(f"qat_started step:{step} lr_mul:{lr_mul:.4f}")
-        mx.synchronize()
-        step_ms = 1000.0 * (time.perf_counter() - step_t0)
-        tok_s = args.train_batch_tokens / max(step_ms / 1000.0, 1e-9)
-        tags = []
-        if qat_active:
-            tags.append("[QAT]")
         if ema is not None:
-            tags.append("[EMA]")
-        if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None):
-            suffix = f" {' '.join(tags)}" if tags else ""
-            log(f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} step_ms:{step_ms:.0f} tok_s:{tok_s:.0f}{suffix}")
-        if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms:
-            stop_after_step = step
-
-    if ema is not None:
-        update_model_params(model, ema.as_mlx())
-        log("ema_applied_for_final_save")
-
-    if do_final_eval:
-        final_val_loss, final_val_bpb = with_eval_params(
-            model,
-            None,
-            lambda: eval_val_sliding(
-                args,
-                compiled_loss,
-                model,
-                val_tokens,
-                base_bytes_lut,
-                has_leading_space_lut,
-                is_boundary_token_lut,
-                log_fn=log,
-            ),
-        )
-        log(f"final_prequant_sliding val_loss:{final_val_loss:.4f} val_bpb:{final_val_bpb:.4f}")
-        log(f"final_prequant_sliding_exact val_loss:{final_val_loss:.8f} val_bpb:{final_val_bpb:.8f}")
-
-    out_path = out_dir / f"{args.run_id}_mlx_model.npz"
+            ema.update(model)
+        if swa is not None:
+            swa.update(model)
+        step_ms = 1000.0 * (time.perf_counter() - step_t0)
+        approx_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0)
+        step += 1
+        if args.train_log_every > 0 and (step <= 5 or step % args.train_log_every == 0):
+            tok_s = args.train_batch_tokens / (step_ms / 1000.0)
+            qat_tag = " [QAT]" if model.use_qat else ""
+            ema_tag = " [EMA]" if ema is not None else ""
+            swa_tag = " [SWA]" if swa is not None else ""
+            log(f"step:{step}/{args.iterations} train_loss:{tl:.4f} "
+                f"step_ms:{step_ms:.0f} tok_s:{tok_s:.0f}{qat_tag}{ema_tag}{swa_tag}")
+        if max_wc_ms and stop_after is None and approx_ms >= max_wc_ms:
+            stop_after = step
+    if swa is not None:
+        swa.apply(model)
+        log("swa_applied_for_save")
+    elif ema is not None:
+        ema.apply(model)
log("ema_applied_for_save") + model.use_qat = False flat_state = {k: v for k, v in tree_flatten(model.state)} + out_path = out_dir / f"{args.run_id}_mlx_model.npz" mx.savez(str(out_path), **flat_state) log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") - - quant_obj, quant_stats = quantize_state_dict_int6(flat_state) + quant_obj, quant_stats = quantize_state_dict_int6(flat_state, args) quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) - quant_blob = zstd.ZstdCompressor(level=22).compress(quant_raw) + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) quant_path = out_dir / f"{args.run_id}_mlx_model.int6.ptz" with quant_path.open("wb") as f: f.write(quant_blob) - quant_file_bytes = quant_path.stat().st_size - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int6_payload_bytes"], 1) - log( - f"serialized_model_int6_zstd:{quant_file_bytes} bytes " - f"(payload:{quant_stats['int6_payload_bytes']} raw_pickle:{len(quant_raw)} payload_ratio:{ratio:.2f}x)" - ) - - quant_flat = dequantize_state_dict_int6(pickle.loads(zstd.ZstdDecompressor().decompress(quant_blob))) + log(f"serialized_int6_zstd:{quant_path.stat().st_size} bytes " + f"(payload:{quant_stats['int6_payload_bytes']} ratio:{quant_stats['baseline_tensor_bytes']/max(quant_stats['int6_payload_bytes'],1):.2f}x)") + with quant_path.open("rb") as f: + quant_blob_disk = f.read() + quant_flat = dequantize_state_dict_int6(pickle.loads(zstandard.ZstdDecompressor().decompress(quant_blob_disk))) model.update(tree_unflatten(list(quant_flat.items()))) - if do_final_eval: - q_val_loss, q_val_bpb = with_eval_params( - model, - None, - lambda: eval_val_sliding( - args, - compiled_loss, - model, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - log_fn=log, - ), - ) - log(f"final_int6_zstd_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f}") + compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state) + eval_mode = args.eval_mode.lower().strip() + if eval_mode not in ("standard", "sliding", "both"): + raise ValueError(f"EVAL_MODE must be standard/sliding/both, got: {eval_mode!r}") + if eval_mode in ("standard", "both"): + qt0 = time.perf_counter() + log("final_eval_mode:standard") + s_val_loss, s_val_bpb = eval_val(args, compiled_loss, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log) + sms = 1000.0 * (time.perf_counter() - qt0) + log(f"final_int6_zstd_roundtrip_standard val_loss:{s_val_loss:.4f} val_bpb:{s_val_bpb:.4f} eval_time:{sms:.0f}ms") + log(f"final_int6_zstd_roundtrip_standard_exact val_loss:{s_val_loss:.8f} val_bpb:{s_val_bpb:.8f}") + q_val_loss, q_val_bpb = s_val_loss, s_val_bpb # used as fallback for the summary lines below + if eval_mode in ("sliding", "both"): + qt0 = time.perf_counter() + if args.ngram_mixer_enabled: + log(f"final_eval_mode:sliding_ngram_mixer eval_seq_len:{args.eval_seq_len} " + f"stride:{args.eval_stride} ngram_alpha:{args.ngram_alpha} " + f"ngram_max_order:{args.ngram_max_order}") + q_val_loss, q_val_bpb = eval_val_sliding_ngram(args, model, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log) + elif args.ttt_enabled: + log(f"final_eval_mode:ttt_sliding rank:{args.ttt_rank} lr:{args.ttt_lr} steps:{args.ttt_steps} stride:{args.eval_stride}") + q_val_loss, q_val_bpb = eval_val_sliding_ttt(args, model, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log) + else: + 
log(f"final_eval_mode:sliding_window eval_seq_len:{args.eval_seq_len} stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding(args, model, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log) + qms = 1000.0 * (time.perf_counter() - qt0) + log(f"final_int6_zstd_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{qms:.0f}ms") + log(f"final_int6_zstd_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + elif eval_mode == "standard": + qms = sms + log(f"final_int6_zstd_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{qms:.0f}ms") log(f"final_int6_zstd_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - if __name__ == "__main__": main()