diff --git a/train_gpt_mlx_kl.py b/train_gpt_mlx_kl.py
index 9cf931ef01..4bb15182a3 100644
--- a/train_gpt_mlx_kl.py
+++ b/train_gpt_mlx_kl.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 import glob, json, math, os, pickle, sys, time, uuid, copy
 import zstandard
+from collections import defaultdict
 from collections.abc import Callable
 from pathlib import Path
 import numpy as np
@@ -81,6 +82,23 @@ class Hyperparameters:
         "LN_SCALE_ENABLED", os.environ.get("USE_LN_SCALE", "1"))))
     xsa_last_n: int = int(os.environ.get("XSA_LAST_N", 4))  # 4: XSA on last N decoder layers
 
+    # EngramLite — gated multi-head bigram+trigram hash (replaces BigramHash when enabled)
+    engram_lite_enabled: bool = bool(int(os.environ.get("ENGRAM_LITE_ENABLED", "0")))
+    engram_hash_size: int = int(os.environ.get("ENGRAM_HASH_SIZE", "2048"))
+    engram_embed_dim: int = int(os.environ.get("ENGRAM_EMBED_DIM", "128"))
+    engram_n_heads: int = int(os.environ.get("ENGRAM_N_HEADS", "2"))
+
+    # SkipGram — non-adjacent token hash logit bias (disabled by default)
+    skipgram_hash_size: int = int(os.environ.get("SKIPGRAM_HASH_SIZE", "0"))
+
+    # Complementary Training — down-weights tokens easily predicted by bigrams
+    complement_alpha: float = float(os.environ.get("COMPLEMENT_ALPHA", "0.0"))
+
+    # BackoffNgramMixer — eval-time causal n-gram mixing (zero artifact cost)
+    ngram_mixer_enabled: bool = bool(int(os.environ.get("NGRAM_MIXER_ENABLED", "0")))
+    ngram_alpha: float = float(os.environ.get("NGRAM_ALPHA", "0.25"))
+    ngram_max_order: int = int(os.environ.get("NGRAM_MAX_ORDER", "4"))
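+
+    # Illustrative invocation (example values, not tuned defaults): all of the new
+    # paths are opt-in via environment variables, e.g.
+    #   ENGRAM_LITE_ENABLED=1 COMPLEMENT_ALPHA=0.3 NGRAM_MIXER_ENABLED=1 \
+    #       python train_gpt_mlx_kl.py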
+
     # Eval mode — controls final roundtrip evaluation strategy
     # "standard" = chunked eval only (fast, ~36 min on M1)
     # "sliding" = sliding-window eval (accurate, ~3× slower) [default]
@@ -247,6 +265,137 @@ def __call__(self, tokens: mx.array) -> mx.array:
         pad = mx.zeros((tokens.shape[0], 1, bigram_emb.shape[-1]), dtype=bigram_emb.dtype)
         return mx.concatenate([pad, bigram_emb], axis=1)  # (B, T, dim)
 
+# ============================================================================
+# EngramLiteEmbedding — gated multi-head bigram+trigram hash logit features.
+# Replaces BigramHash when ENGRAM_LITE_ENABLED=1.
+# Key improvement: gating suppresses noisy hash collisions (raw trigrams hurt).
+# ============================================================================
+class EngramLiteEmbedding(nn.Module):
+    def __init__(self, hash_size: int = 2048, embed_dim: int = 128,
+                 output_dim: int = 1024, n_heads: int = 2,
+                 orders: tuple = (2, 3)):
+        super().__init__()
+        self.hash_size = hash_size
+        self.embed_dim = embed_dim
+        self.output_dim = output_dim
+        self.n_heads = n_heads
+        self.orders = list(orders)
+        # Different prime per hash head to reduce collision rate (at most 4 heads supported)
+        _all_primes = [31337, 59999, 73721, 97531]
+        if n_heads > len(_all_primes):
+            raise ValueError(f"EngramLiteEmbedding: n_heads={n_heads} exceeds max supported ({len(_all_primes)})")
+        self._primes = _all_primes[:n_heads]
+
+        # Separate embedding table per n-gram order (small dim, projected later)
+        self.tables = {
+            f"order_{o}": nn.Embedding(hash_size, embed_dim)
+            for o in orders
+        }
+        for tbl in self.tables.values():
+            tbl.weight = tbl.weight * 0.01
+
+        # Project from embed_dim → output_dim (vocab_size)
+        self.proj = nn.Linear(embed_dim, output_dim, bias=False)
+        self.proj.weight = self.proj.weight * 0.01
+
+        # Learned gate per n-gram order — starts mostly suppressed (sigmoid(-2) ≈ 0.12)
+        self.gate_proj = nn.Linear(embed_dim, len(orders), bias=True)
+        self.gate_proj.bias = mx.full((len(orders),), -2.0)
+        self.gate_proj.weight = self.gate_proj.weight * 0.01
+
+    def _hash_ngram(self, tokens: mx.array, order: int, head_idx: int):
+        """Hash n-gram context for given order and hash head."""
+        prime = self._primes[head_idx]
+        if order == 2:
+            t_prev = tokens[:, :-1]
+            t_curr = tokens[:, 1:]
+            idx = mx.remainder(t_prev * prime + t_curr, self.hash_size)
+            valid_start = 1
+        elif order == 3:
+            t_prev2 = tokens[:, :-2]
+            t_prev1 = tokens[:, 1:-1]
+            t_curr = tokens[:, 2:]
+            idx = mx.remainder(
+                t_prev2 * (prime * prime) + t_prev1 * prime + t_curr,
+                self.hash_size
+            )
+            valid_start = 2
+        else:
+            raise ValueError(f"n-gram order {order} not supported")
+        return idx, valid_start
+
+    def __call__(self, tokens: mx.array) -> mx.array:
+        """tokens: (B, T) → (B, T, output_dim) additive logit bias"""
+        B, T = tokens.shape
+        combined = mx.zeros((B, T, self.embed_dim), dtype=mx.float32)
+
+        for order in self.orders:
+            tbl = self.tables[f"order_{order}"]
+            # Multi-head average to reduce collision noise
+            head_sum = None
+            for hi in range(self.n_heads):
+                idx, valid_start = self._hash_ngram(tokens, order, hi)
+                emb = tbl(idx).astype(mx.float32)  # (B, T-valid_start, embed_dim)
+                pad = mx.zeros((B, valid_start, self.embed_dim), dtype=mx.float32)
+                emb = mx.concatenate([pad, emb], axis=1)  # (B, T, embed_dim)
+                head_sum = emb if head_sum is None else head_sum + emb
+            combined = combined + head_sum / self.n_heads
+
+        # Sigmoid gate: suppress noisy lookups, let model learn when to trust them.
+        # Gate is averaged across all n-gram orders into a single scalar per position —
+        # empirically simpler and stable; the combined embedding already encodes order info.
+        gate = mx.sigmoid(self.gate_proj(combined))      # (B, T, n_orders)
+        gate_scalar = gate.mean(axis=-1, keepdims=True)  # (B, T, 1)
+        return self.proj(combined) * gate_scalar         # (B, T, output_dim)
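+
+
+# Worked numbers (illustrative, not measured): with prime 31337 and hash_size 2048,
+# the bigram (prev=17, curr=42) lands in bucket (17 * 31337 + 42) % 2048 = 291,
+# while the second head's prime 59999 sends the same pair to bucket 121, so
+# averaging heads dilutes any single collision. At init sigmoid(-2.0) ≈ 0.12, so
+# the gate passes roughly 12% of the projected bias until training learns where
+# the hash features can be trusted.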
+
+
+# ============================================================================
+# SkipGramHashEmbedding — non-adjacent token hash logit bias.
+# Captures structured patterns (e.g. token[-1] × token[-3]).
+# Enabled when SKIPGRAM_HASH_SIZE > 0.
+# ============================================================================
+class SkipGramHashEmbedding(nn.Module):
+    def __init__(self, hash_size: int = 4096, dim: int = 1024,
+                 skip_patterns: list = None):
+        super().__init__()
+        self.hash_size = hash_size
+        self.dim = dim
+        # Each pattern is a list of negative offsets, e.g. [-1, -3]
+        self.skip_patterns = skip_patterns if skip_patterns is not None else [[-1, -3], [-1, -5], [-2, -4]]
+        self.tables = {
+            f"skip_{i}": nn.Embedding(hash_size, dim)
+            for i in range(len(self.skip_patterns))
+        }
+        for tbl in self.tables.values():
+            tbl.weight = tbl.weight * 0.01
+
+    def __call__(self, tokens: mx.array) -> mx.array:
+        """tokens: (B, T) → (B, T, dim) additive logit bias"""
+        B, T = tokens.shape
+        output = mx.zeros((B, T, self.dim), dtype=mx.float32)
+
+        for i, pattern in enumerate(self.skip_patterns):
+            tbl = self.tables[f"skip_{i}"]
+            min_offset = min(pattern)  # most negative offset
+            valid_start = abs(min_offset)
+            if valid_start >= T:
+                continue
+            # Accumulate hash over all offsets
+            hash_val = mx.zeros((B, T - valid_start), dtype=mx.int32)
+            prime = 31337
+            for offset in pattern:
+                start = valid_start + offset
+                end = T + offset
+                hash_val = hash_val * prime + tokens[:, start:end]
+            idx = mx.remainder(mx.abs(hash_val), self.hash_size)
+            emb = tbl(idx).astype(mx.float32)
+            pad = mx.zeros((B, valid_start, self.dim), dtype=mx.float32)
+            emb = mx.concatenate([pad, emb], axis=1)
+            output = output + emb
+
+        return output
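+
+
+# Pattern sketch (illustrative): for skip_pattern [-1, -3], valid_start is 3 and the
+# bias at position t hashes tokens t-1 and t-3 together,
+#   idx = (tok[t-1] * 31337 + tok[t-3]) % hash_size,
+# deliberately skipping t-2; adjacent pairs are already covered by the
+# BigramHash/EngramLite path, so these tables only add the non-adjacent signal.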
+
+
 # ============================================================================
 # INNOVATION: SmearGate — blend each token embedding with previous token's
 # Technique: @unnir (PR #102/#135). Gate initialized to 3.0 → sigmoid≈0.95 pass-through.
@@ -414,7 +563,10 @@ def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads,
                  mlp_mult, logit_chunk_tokens, logit_softcap, rope_base,
                  tied_embed_init_std, qk_gain_init, bigram_hash_size,
                  use_ortho_init, rope_dims: int = 0, xsa_last_n: int = 0,
-                 use_ln_scale: bool = True, smear_enabled: bool = True):
+                 use_ln_scale: bool = True, smear_enabled: bool = True,
+                 engram_lite_enabled: bool = False, engram_hash_size: int = 2048,
+                 engram_embed_dim: int = 128, engram_n_heads: int = 2,
+                 skipgram_hash_size: int = 0):
         super().__init__()
         self.logit_chunk_tokens = logit_chunk_tokens
         self.logit_softcap = logit_softcap
@@ -440,8 +592,18 @@ def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads,
         ]
         self.final_norm = RMSNormNoWeight()
 
-        # INNOVATION: BigramHash on logits (None when bigram_hash_size=0)
-        self.bigram_hash = BigramHashEmbedding(bigram_hash_size, vocab_size) if bigram_hash_size > 0 else None
+        # Logit bias modules: EngramLite replaces BigramHash when enabled
+        if engram_lite_enabled:
+            self.engram_lite = EngramLiteEmbedding(
+                hash_size=engram_hash_size, embed_dim=engram_embed_dim,
+                output_dim=vocab_size, n_heads=engram_n_heads, orders=(2, 3))
+            self.bigram_hash = None
+        else:
+            self.engram_lite = None
+            self.bigram_hash = BigramHashEmbedding(bigram_hash_size, vocab_size) if bigram_hash_size > 0 else None
+
+        # SkipGram logit bias (additive, independent of BigramHash/EngramLite)
+        self.skipgram_hash = SkipGramHashEmbedding(hash_size=skipgram_hash_size, dim=vocab_size) if skipgram_hash_size > 0 else None
 
         # Zero-init output projections
         for b in self.blocks:
@@ -489,20 +651,57 @@ def __call__(self, input_ids: mx.array) -> mx.array:
             x = self.blocks[self.num_encoder_layers + i](x, x0, qat)
         return self.final_norm(x)
 
+    def _add_logit_biases(self, logits: mx.array, input_ids: mx.array) -> mx.array:
+        """Add all enabled logit bias modules (BigramHash/EngramLite/SkipGram)."""
+        vocab = self.tok_emb.weight.shape[0]
+        if self.engram_lite is not None:
+            bias = self.engram_lite(input_ids).reshape(-1, vocab)
+            logits = logits + bias.astype(logits.dtype)
+        elif self.bigram_hash is not None:
+            bias = self.bigram_hash(input_ids).reshape(-1, vocab)
+            logits = logits + bias.astype(logits.dtype)
+        if self.skipgram_hash is not None:
+            bias = self.skipgram_hash(input_ids).reshape(-1, vocab)
+            logits = logits + bias.astype(logits.dtype)
+        return logits
+
     def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
         x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
         y = target_ids.reshape(-1)
+        logits = x @ self.tok_emb.weight.astype(x.dtype).T
+        logits = self.softcap(logits)
+        logits = self._add_logit_biases(logits, input_ids)
+        return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean")
+
+    def complementary_loss(self, input_ids: mx.array, target_ids: mx.array,
+                           bigram_probs: mx.array, alpha: float) -> mx.array:
+        """Cross-entropy loss that down-weights tokens easily predicted by bigrams.
+
+        For token at position t with predecessor prev:
+            weight[t] = 1 - alpha * P_bigram(target[t] | prev[t])
+        Weights are clipped to [0.1, 1.0] and normalized so the effective
+        learning rate is preserved.
+
+        Args:
+            bigram_probs: (V, V) float32 pre-computed P(next|prev) matrix.
+            alpha: strength of complementary weighting (0 = standard CE).
+        """
+        x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
+        y = target_ids.reshape(-1)
         logits = x @ self.tok_emb.weight.astype(x.dtype).T
         logits = self.softcap(logits)
+        logits = self._add_logit_biases(logits, input_ids)
-        # Add BigramHash logit bias (skipped when bigram_hash is None)
-        if self.bigram_hash is not None:
-            bigram_bias = self.bigram_hash(input_ids)  # (B, T, vocab)
-            bigram_bias = bigram_bias.reshape(-1, bigram_bias.shape[-1])
-            logits = logits + bigram_bias.astype(logits.dtype)
+        ce_per_token = nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="none")
-        return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean")
+
+        # prev_tokens: the input token at each position (= predecessor of target)
+        prev_tokens = input_ids.reshape(-1)
+        p_bigram = bigram_probs[prev_tokens, y]  # (B*T,)
+        weights = 1.0 - alpha * p_bigram.astype(mx.float32)
+        weights = mx.clip(weights, 0.1, 1.0)
+        weights = weights / (weights.mean() + 1e-8)  # normalize: E[weight]=1 preserves effective LR
+
+        return (ce_per_token * weights).mean()
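+        # Worked example (illustrative): with alpha = 0.5, a target whose bigram
+        # probability is 0.9 keeps weight 1 - 0.5 * 0.9 = 0.55, while one with
+        # bigram probability 0.05 keeps weight 0.975 (before the mean
+        # normalization above), so gradient budget shifts toward tokens the
+        # bigram table cannot already explain; the 0.1 clip floor keeps every
+        # token in play.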
 
     def token_losses(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
         """Return (B, T) per-token NLL — used for sliding-window eval."""
@@ -511,12 +710,19 @@ def token_losses(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
         y = target_ids.reshape(-1)
         logits = x @ self.tok_emb.weight.astype(x.dtype).T
         logits = self.softcap(logits)
-        if self.bigram_hash is not None:
-            bigram_bias = self.bigram_hash(input_ids).reshape(-1, self.tok_emb.weight.shape[0])
-            logits = logits + bigram_bias.astype(logits.dtype)
+        logits = self._add_logit_biases(logits, input_ids)
         nll = nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="none")
         return nll.reshape(B, T)
+
+    def token_logits(self, input_ids: mx.array) -> mx.array:
+        """Return (B, T, V) raw logits — used by BackoffNgramMixer for mixing."""
+        B, T = input_ids.shape
+        x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
+        logits = x @ self.tok_emb.weight.astype(x.dtype).T
+        logits = self.softcap(logits)
+        logits = self._add_logit_biases(logits, input_ids)
+        return logits.reshape(B, T, -1)
 
 # ============================================================================
 # OPTIMIZERS (same structure as baseline)
 # ============================================================================
@@ -551,16 +757,19 @@ def __init__(self, model, args):
         self.args = args
         params = dict(tree_flatten(model.parameters()))
         self.embed_key = "tok_emb.weight"
+        _module_prefixes = (
+            "blocks.", "bigram_hash.", "engram_lite.", "skipgram_hash.",
+        )
         self.matrix_keys = [
             k for k, p in params.items()
-            if (k.startswith("blocks.") or k.startswith("bigram_hash."))
+            if any(k.startswith(pfx) for pfx in _module_prefixes)
            and p.ndim == 2
            and not any(pat in k for pat in CONTROL_TENSOR_NAME_PATTERNS)
         ]
         self.scalar_keys = [
             k for k, p in params.items()
             if k == "skip_weights" or (
-                (k.startswith("blocks.") or k.startswith("bigram_hash."))
+                any(k.startswith(pfx) for pfx in _module_prefixes)
                and (p.ndim < 2 or any(pat in k for pat in CONTROL_TENSOR_NAME_PATTERNS))
             )
         ]
@@ -776,6 +985,164 @@ def dequantize_state_dict_int6(quant_obj):
         out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig]) if isinstance(orig, str) else mx.array(out_arr)
     return out
 
+# ============================================================================
+# COMPLEMENTARY TRAINING HELPER
+# ============================================================================
+def build_bigram_stats(data_path: str, vocab_size: int = 1024) -> np.ndarray:
+    """Pre-compute P(next_token | prev_token) from all training shards.
+
+    Returns a (vocab_size, vocab_size) float32 array where entry [i, j] is
+    the smoothed probability of token j following token i. Uses Laplace
+    smoothing so every entry is > 0.
+
+    The result is used for complementary training (down-weighting tokens that
+    bigram statistics can already predict well) and is NOT stored in the
+    artifact — it is recomputed from training data at the start of each run.
+    """
+    counts = np.zeros((vocab_size, vocab_size), dtype=np.float64)
+    # Shard pattern is the same as args.train_files (fineweb_train_*.bin)
+    shard_paths = sorted(glob.glob(f"{data_path}/fineweb_train_*.bin"))
+    for shard_path in shard_paths:
+        tokens = load_data_shard(Path(shard_path))
+        prev = tokens[:-1].astype(np.int32)
+        curr = tokens[1:].astype(np.int32)
+        mask = (prev < vocab_size) & (curr < vocab_size)
+        np.add.at(counts, (prev[mask], curr[mask]), 1.0)
+    # Laplace smoothing: add 1 to every cell, normalize per row
+    counts += 1.0
+    row_sums = counts.sum(axis=1, keepdims=True)
+    return (counts / row_sums).astype(np.float32)
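+
+
+# Smoothing sketch (illustrative counts): with vocab_size 1024, a prev-token row
+# with 10,240 observed continuations is normalized as (count + 1) / (10,240 + 1,024),
+# so an unseen continuation still gets ~8.9e-5 probability instead of zero, which
+# keeps the complementary weights 1 - alpha * P_bigram strictly below 1.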
+
+
+# ============================================================================
+# BACKOFF N-GRAM MIXER
+# Causal, zero-artifact-cost eval-time n-gram language model.
+#
+# Competition compliance notes:
+# - Produces a full normalized probability distribution over the vocabulary
+#   at each step (sums to 1 by construction).
+# - Strictly causal: only tokens at positions < current position are used.
+# - No artifact cost: the cache is built from scratch during each evaluation
+#   pass using the tokens already scored — it is never saved to disk.
+# ============================================================================
+class BackoffNgramMixer:
+    """Causal n-gram LM with linear-interpolation backoff.
+
+    For each evaluation position t the mixer:
+    1. Queries count tables built from tokens at positions 0 .. t-1.
+    2. Produces P_ngram(· | context) via linear interpolation from order 1
+       up to max_order — a valid probability distribution over all vocab_size
+       tokens.
+    3. Mixes with the neural model's distribution:
+           P_mix = (1 - α) · P_neural + α · P_ngram
+    4. Scores the true next token under P_mix.
+    5. Updates the count tables with the token just scored (so it can be used
+       as context for future positions).
+    """
+
+    def __init__(self, vocab_size: int = 1024, max_order: int = 4,
+                 hash_buckets: int = 2_000_000,  # ~2M buckets → ~16 n-grams per bucket at 32M tokens
+                 alpha_mode: str = "entropy_adaptive",
+                 fixed_alpha: float = 0.25):
+        self.vocab_size = vocab_size
+        self.max_order = max_order
+        self.hash_buckets = hash_buckets
+        self.alpha_mode = alpha_mode
+        self.fixed_alpha = fixed_alpha
+        self._reset()
+
+    def _reset(self):
+        """Clear all count tables — call before each new eval pass."""
+        # counts[order][ctx_hash] -> float32 array of shape (vocab_size,)
+        self._counts = [
+            defaultdict(lambda: np.zeros(self.vocab_size, dtype=np.float32))
+            for _ in range(self.max_order + 1)
+        ]
+        self._total = [defaultdict(float) for _ in range(self.max_order + 1)]
+
+    def _hash_ctx(self, context_tokens) -> int:
+        h = 0
+        for t in context_tokens:
+            h = (h * 31337 + int(t)) % self.hash_buckets
+        return h
+
+    def _ngram_probs(self, context_tokens) -> np.ndarray:
+        """Interpolated n-gram distribution P(· | context). Sums to 1."""
+        V = self.vocab_size
+        probs = np.ones(V, dtype=np.float64) / V  # uniform prior (order 0)
+        for order in range(1, self.max_order + 1):
+            if len(context_tokens) < order:
+                break
+            ctx_hash = self._hash_ctx(context_tokens[-order:])
+            total = self._total[order][ctx_hash]
+            if total <= 0.0:
+                continue
+            # Confidence-weighted interpolation: λ→1 when total>>5 (counts well established)
+            lam = total / (total + 5.0)  # 5.0: discount factor; reaches λ=0.5 at 5 observations
+            c = self._counts[order][ctx_hash].astype(np.float64)
+            order_probs = (c + 1e-10) / (total + 1e-10 * V)
+            order_probs /= order_probs.sum()
+            probs = (1.0 - lam) * probs + lam * order_probs
+        # Final normalization (guard against floating-point drift)
+        s = probs.sum()
+        if s > 0:
+            probs /= s
+        return probs.astype(np.float32)
+
+    def _mixing_alpha(self, neural_logits: np.ndarray) -> float:
+        """Entropy-adaptive mixing weight α ∈ [0.15, 0.60]."""
+        if self.alpha_mode == "fixed":
+            return self.fixed_alpha
+        # High neural entropy → trust n-grams more
+        logits = neural_logits.astype(np.float64)
+        logits -= logits.max()
+        probs = np.exp(logits)
+        probs /= probs.sum()
+        entropy = float(-np.sum(probs * np.log2(probs + 1e-10)))
+        max_entropy = math.log2(self.vocab_size)
+        normalized = entropy / max_entropy
+        return 0.15 + 0.45 * normalized  # α ∈ [0.15, 0.60]: min trust even when confident; max when fully uncertain
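+
+    # Calibration sketch (illustrative): with vocab_size 1024 the maximum entropy is
+    # log2(1024) = 10 bits. A peaked neural distribution at ~2 bits gives
+    # alpha ≈ 0.15 + 0.45 * 0.2 = 0.24, while a near-uniform one at ~9 bits gives
+    # alpha ≈ 0.56, so the n-gram cache is weighted most where the network is
+    # least confident.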
+
+    def score_and_update(self, context_tokens, target_token: int,
+                         neural_logits: np.ndarray) -> float:
+        """Score target_token and update the cache. Must be called in order.
+
+        Args:
+            context_tokens: sequence of token IDs before the current position.
+            target_token: the true next token to score.
+            neural_logits: (V,) float32/64 raw logits from the neural model
+                at the current position.
+
+        Returns:
+            log_prob: natural-log probability of target_token under P_mix.
+        """
+        ngram_probs = self._ngram_probs(context_tokens)
+        alpha = self._mixing_alpha(neural_logits)
+
+        # Normalize neural distribution
+        nl = neural_logits.astype(np.float64)
+        nl -= nl.max()
+        neural_probs = np.exp(nl)
+        neural_probs /= neural_probs.sum()
+
+        mixed = (1.0 - alpha) * neural_probs + alpha * ngram_probs.astype(np.float64)
+        s = mixed.sum()
+        if s > 0:
+            mixed /= s
+
+        log_prob = float(np.log(mixed[target_token] + 1e-40))
+
+        # Update cache AFTER scoring (causal: this token becomes future context)
+        tok = int(target_token)
+        for order in range(1, self.max_order + 1):
+            if len(context_tokens) >= order:
+                ctx_hash = self._hash_ctx(context_tokens[-order:])
+                self._counts[order][ctx_hash][tok] += 1.0
+                self._total[order][ctx_hash] += 1.0
+
+        return log_prob
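+
+    # Mixing sketch (illustrative numbers): with alpha = 0.25, a neural probability
+    # of 0.40 and an n-gram probability of 0.10 for the true token give
+    # P_mix = 0.75 * 0.40 + 0.25 * 0.10 = 0.325, i.e. about -1.12 nats; the n-gram
+    # term only moves the score where the two distributions disagree.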
+
+
 # ============================================================================
 # VALIDATION HELPERS
 # ============================================================================
@@ -950,7 +1317,101 @@ def eval_val_sliding(args, model, val_tokens, base_bytes_lut, has_leading_space_
     val_loss, val_bpb
 
-def clip_grad_tree(grads_tree, max_norm):
+def eval_val_sliding_ngram(args, model, val_tokens,
+                           base_bytes_lut, has_leading_space_lut,
+                           is_boundary_token_lut, log_fn=None):
+    """Sliding-window eval with BackoffNgramMixer post-processing.
+
+    Identical windowing strategy to eval_val_sliding, but each "new" token
+    in each window is scored under a mixture of the neural distribution and
+    a causal n-gram distribution built incrementally from all previously
+    scored tokens.
+
+    Competition compliance:
+    - The n-gram cache only sees tokens at positions strictly before the
+      current position (causal).
+    - The mixed distribution sums to 1 at every position.
+    - The n-gram cache is built from scratch during the eval pass and is
+      never saved to disk (zero artifact cost).
+    """
+    seq_len = args.eval_seq_len
+    stride = args.eval_stride
+    batch_seqs = args.eval_batch_seqs
+    max_order = args.ngram_max_order
+    total_tokens = val_tokens.size - 1
+
+    window_starts = [ws for ws in range(0, total_tokens, stride)
+                     if min(ws + seq_len, total_tokens) - ws >= 1]
+    total_windows = len(window_starts)
+
+    loss_sum = 0.0
+    token_count = 0.0
+    byte_count = 0.0
+
+    model.use_qat = False
+    mixer = BackoffNgramMixer(
+        vocab_size=args.vocab_size,
+        max_order=max_order,
+        alpha_mode="entropy_adaptive",
+        fixed_alpha=args.ngram_alpha,
+    )
+
+    for bi in range(0, total_windows, batch_seqs):
+        batch_ws = window_starts[bi:bi + batch_seqs]
+        bsz = len(batch_ws)
+
+        x_np = np.zeros((bsz, seq_len), dtype=np.int32)
+        y_np = np.zeros((bsz, seq_len), dtype=np.int32)
+        wlens: list[int] = []
+        for i, ws in enumerate(batch_ws):
+            end = min(ws + seq_len, total_tokens)
+            wlen = end - ws
+            wlens.append(wlen)
+            x_np[i, :wlen] = val_tokens[ws:end]
+            y_np[i, :wlen] = val_tokens[ws + 1:end + 1]
+
+        # Get full (B, T, V) logits for the batch
+        x = mx.array(x_np)
+        logits_all = model.token_logits(x)  # (B, T, V)
+        mx.eval(logits_all)
+        logits_np = np.array(logits_all)  # (B, T, V) float32
+
+        for i, ws in enumerate(batch_ws):
+            wlen = wlens[i]
+            s = 0 if ws == 0 else max(wlen - stride, 0)
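+            # Stride bookkeeping (illustrative values): with eval_seq_len 1024 and
+            # eval_stride 256, every window after the first scores only its last
+            # 256 targets (s = wlen - stride), so each validation token is scored
+            # exactly once while still seeing up to 768 tokens of left context.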
+
+            for j in range(s, wlen):
+                global_pos = ws + j          # global index of the input token
+                target = int(y_np[i, j])     # = val_tokens[global_pos + 1]
+                neural_logits = logits_np[i, j]  # (V,)
+
+                # N-gram context: all tokens before global_pos + 1
+                ctx_start = max(0, global_pos + 1 - max_order)
+                context = val_tokens[ctx_start:global_pos + 1].tolist()
+
+                log_prob = mixer.score_and_update(context, target, neural_logits)
+                loss_sum += -log_prob
+                token_count += 1.0
+                tgt_arr = y_np[i, j:j + 1]
+                prev_arr = x_np[i, j:j + 1]
+                tb = float(base_bytes_lut[tgt_arr[0]])
+                tb += float(has_leading_space_lut[tgt_arr[0]] and not is_boundary_token_lut[prev_arr[0]])
+                byte_count += tb
+
+        if log_fn and (bi // batch_seqs) % 50 == 0:
+            done = min(bi + batch_seqs, total_windows)
+            pct = done / total_windows * 100
+            rbpb = 0.0
+            if token_count > 0 and byte_count > 0:
+                rbpb = (loss_sum / token_count) / math.log(2.0) * (token_count / byte_count)
+            log_fn(f"ngram_sliding_eval [{pct:5.1f}%] {done}/{total_windows} windows running_bpb={rbpb:.6f}")
+
+    val_loss = loss_sum / max(token_count, 1.0)
+    val_bpb = (val_loss / math.log(2.0)) * (token_count / max(byte_count, 1.0))
+    return val_loss, val_bpb
+
+
+def clip_grad_tree(grads_tree, max_norm):
     if max_norm <= 0:
         return grads_tree
     flat = dict(tree_flatten(grads_tree))
@@ -1112,6 +1573,11 @@ def log(msg, console=True):
     use_ortho_init=args.use_ortho_init, rope_dims=args.rope_dims,
     xsa_last_n=args.xsa_last_n, use_ln_scale=args.ln_scale_enabled,
     smear_enabled=args.smear_enabled,
+    engram_lite_enabled=args.engram_lite_enabled,
+    engram_hash_size=args.engram_hash_size,
+    engram_embed_dim=args.engram_embed_dim,
+    engram_n_heads=args.engram_n_heads,
+    skipgram_hash_size=args.skipgram_hash_size,
 )
 opt = SplitOptimizers(model, args)
@@ -1120,9 +1586,29 @@ def log(msg, console=True):
 
 # SWA buffer — starts at 60% of training when USE_SWA=1
 swa = None
-compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state)
+# Complementary training: precompute bigram stats once if needed
+bigram_probs_mx = None
+if args.complement_alpha > 0.0:
+    log("complement_training: building bigram stats from training shards...")
+    _bp_np = build_bigram_stats(args.data_path, args.vocab_size)
+    bigram_probs_mx = mx.array(_bp_np, dtype=mx.float32)
+    mx.eval(bigram_probs_mx)
+    log(f"complement_training: bigram stats ready (alpha={args.complement_alpha})")
+    del _bp_np
+
+# Wire up loss functions — use complementary loss when alpha > 0
+if bigram_probs_mx is not None:
+    _alpha = args.complement_alpha
+    _bp = bigram_probs_mx
+    def _loss_fn(x, y):
+        return model.complementary_loss(x, y, _bp, _alpha)
+else:
+    def _loss_fn(x, y):
+        return model.loss(x, y)
+
+compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state)
 compiled_loss_and_grad = mx.compile(
-    nn.value_and_grad(model, lambda x, y: model.loss(x, y)),
+    nn.value_and_grad(model, _loss_fn),
     inputs=model.state, outputs=model.state,
 )
@@ -1136,6 +1622,10 @@ def log(msg, console=True):
     f"ln_scale={args.ln_scale_enabled} xsa_last_n={args.xsa_last_n} xsa_layers={xsa_layers} "
     f"gptq_lite={args.use_gptq_lite} ttt={args.ttt_enabled} eval_mode={args.eval_mode} "
     f"use_swa={args.use_swa} swa_decay={args.swa_decay}")
+log(f"moonshot: engram_lite={args.engram_lite_enabled} skipgram_hash={args.skipgram_hash_size} "
+    f"complement_alpha={args.complement_alpha} "
+    f"ngram_mixer={args.ngram_mixer_enabled} ngram_alpha={args.ngram_alpha} "
+    f"ngram_max_order={args.ngram_max_order}")
 log(f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} "
     f"grad_accum:{args.grad_accum_steps} seq_len:{args.train_seq_len}")
 log(f"optimizer: muon_keys:{len(opt.matrix_keys)} scalar_keys:{len(opt.scalar_keys)}")
@@ -1177,7 +1667,7 @@ def log(msg, console=True):
 if _avg is not None:
     saved_state = {k: mx.array(v) for k, v in tree_flatten(model.parameters())}
     _avg.apply(model)
-    compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state)
+    compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state)
 
 model.use_qat = False  # No QAT during eval
 val_loss, val_bpb = eval_val(args, compiled_loss, val_tokens,
@@ -1187,9 +1677,9 @@ def log(msg, console=True):
 
 if _avg is not None:
     model.update(tree_unflatten(list(saved_state.items())))
-    compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state)
+    compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state)
     compiled_loss_and_grad = mx.compile(
-        nn.value_and_grad(model, lambda x, y: model.loss(x, y)),
+        nn.value_and_grad(model, _loss_fn),
         inputs=model.state, outputs=model.state)
 
 t0 = time.perf_counter()
@@ -1210,10 +1700,9 @@ def log(msg, console=True):
     if _new_use_qat:
         log(f"qat_started:step={step} lr_mul={lr_mul:.4f} — recompiling graph")
         compiled_loss = mx.compile(
-            lambda x, y: model.loss(x, y),
-            inputs=model.state, outputs=model.state)
+            _loss_fn, inputs=model.state, outputs=model.state)
         compiled_loss_and_grad = mx.compile(
-            nn.value_and_grad(model, lambda x, y: model.loss(x, y)),
+            nn.value_and_grad(model, _loss_fn),
             inputs=model.state, outputs=model.state)
 
     # Initialize EMA after ema_start_frac; initialize SWA at 60% of iterations
@@ -1296,7 +1785,7 @@
     quant_blob_disk = f.read()
 quant_flat = dequantize_state_dict_int6(pickle.loads(zstandard.ZstdDecompressor().decompress(quant_blob_disk)))
 model.update(tree_unflatten(list(quant_flat.items())))
-compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state)
+compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state)
 eval_mode = args.eval_mode.lower().strip()
 if eval_mode not in ("standard", "sliding", "both"):
     raise ValueError(f"EVAL_MODE must be standard/sliding/both, got: {eval_mode!r}")
@@ -1315,7 +1804,13 @@
 # sliding path (run for "sliding" or "both")
 if eval_mode in ("sliding", "both"):
     qt0 = time.perf_counter()
-    if args.ttt_enabled:
+    if args.ngram_mixer_enabled:
+        log(f"final_eval_mode:sliding_ngram_mixer eval_seq_len:{args.eval_seq_len} "
+            f"stride:{args.eval_stride} ngram_alpha:{args.ngram_alpha} "
+            f"ngram_max_order:{args.ngram_max_order}")
+        q_val_loss, q_val_bpb = eval_val_sliding_ngram(args, model, val_tokens,
+            base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log)
+    elif args.ttt_enabled:
         log(f"final_eval_mode:ttt_sliding rank:{args.ttt_rank} lr:{args.ttt_lr} steps:{args.ttt_steps} stride:{args.eval_stride}")
         q_val_loss, q_val_bpb = eval_val_sliding_ttt(args, model, val_tokens,
             base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log)