diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/README.md b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/README.md new file mode 100644 index 0000000000..64c6fcc5d6 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/README.md @@ -0,0 +1,126 @@ +# Record: Split-LR + N-gram Agreement + Full Hessian GPTQ + Brotli + +**val_bpb: 1.1078** (3-seed mean, std 0.0009) | **1.8752 nats** | **~15.86 MB** | 8xH100 SXM, 600s train + 449s eval + +Built on [PR #1179](https://github.com/openai/parameter-golf/pull/1179) by @dexhunter (training) and [PR #1145](https://github.com/openai/parameter-golf/pull/1145) by @AnirudhRahul (n-gram agreement evaluation). + +## Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128) + +| Seed | Steps | ms/step | Sliding BPB | **N-gram BPB** | Artifact | +|------|-------|---------|-------------|----------------|----------| +| 1337 | ~6780 | 88.0 | 1.1110 | **1.1083** | 15,853,466 | +| 42 | ~6780 | 88.0 | 1.1095 | **1.1068** | 15,857,705 | +| 2025 | ~6780 | 88.0 | 1.1112 | **1.1085** | 15,846,914 | +| **Mean** | | | **1.1106** | **1.1078** | | + +SOTA (PR #1019, 3-seed mean): **1.8822 nats**. This run: **1.8752 nats**. Delta: **-0.00697 nats**. Clears the 0.005-nat threshold. + +### Timing Budget + +| Phase | Time | +|-------|------| +| Training (wallclock cap) | ~591s | +| GPTQ calibration (reserved) | ~7s | +| Post-EMA eval | ~2s | +| Int6 roundtrip eval | ~7s | +| Sliding window eval (stride=64) | ~78s | +| **N-gram agreement eval** | **~449s** | +| **Total eval** | **~536s** | + +## What's New vs PR #1019 + +### Training improvements (from PR #1179) +1. **Split-LR** — different learning rates for early (0.025) vs late (0.030) layers +2. **BigramHash(2816x160)** — wider projection (160 vs 112), fewer buckets +3. **Sigmoid-gated U-Net** — learnable gates on encoder-decoder skip connections +4. **Soft-round QAT** — temperature-controlled rounding (alpha 1->16) replacing STE +5. **Brotli-11 + byte-shuffle** — saves ~400KB vs LZMA +6. **Coprime-stride data loader** — better data shuffling and coverage + +### Evaluation improvement (from PR #1145) +7. **Online n-gram agreement** — 3 causal experts (token n-gram, within-word, word-start) with agreement boosting. Adjusts LLM probabilities via properly normalized exponential tilting. Contributes **-0.0028 BPB**. + +## N-gram Agreement: How It Works + +Three online n-gram experts predict the next token using only already-scored (past) tokens: +- **Token n-gram** (16-gram context, hash table): predicts based on raw token patterns +- **Within-word continuation**: predicts next subword within the current word +- **Word-start hints**: predicts first token of next word based on previous word context + +For each position, the expert with highest expected gain is selected. When 2+ experts agree on the same token, their boost is increased. The LLM's probability is adjusted via exponential tilting: + +``` +p_adjusted = (scale * p_true) / (1 - p_hint + scale * p_hint) +``` + +This produces a properly normalized distribution (sums to 1.0). The adjustment is: +- **Causal**: each expert predicts BEFORE updating its state with the target token +- **Score-first**: runs under `torch.inference_mode()`, no model parameters modified +- **Properly normalized**: exponential tilting with correct partition function + +## Legality + +- Standard F.cross_entropy for training +- N-gram agreement: causal, score-first, properly normalized (exponential tilting) +- No training on validation data +- No SLOT, no multi-epoch TTT +- GPTQ calibration within training budget +- Artifact < 16,000,000 bytes (all seeds) +- Training <= 600s, eval <= 600s (all seeds) + +## Architecture + +| Component | Setting | +|-----------|---------| +| Layers | 11 (512d, 8 GQA heads, 4 KV heads) | +| MLP | 3x (1536) with LeakyReLU(0.5)^2 | +| Attention | XSA on all 11 layers | +| BigramHash | 2816 x dim=160 | +| Split-LR | early=0.025, late=0.030, bank_split=5 | +| Skip connections | Sigmoid-gated U-Net | +| QAT | Soft-round (alpha ramp 1->16) | +| RoPE | Partial (16/64 dims) | +| LN Scale | 1/sqrt(layer+1) | +| VE128 | Layers 9-10 | +| SmearGate | Position-mixing gate | +| Weight avg | EMA(0.997) + SWA(every 50) | +| Quantization | Full Hessian GPTQ int6 | +| Compression | Brotli quality=11 + byte-shuffle | +| Optimizer | Parallel Muon + Parameter Banking | +| Eval | Online n-gram agreement (token 16-gram + within-word + word-start) | + +## Run Command + +```bash +# Training (3 seeds) +pip install brotli +for SEED in 1337 42 2025; do + BIGRAM_DIM=160 SEED=$SEED \ + torchrun --standalone --nproc_per_node=8 train_gpt.py 2>&1 | tee train_seed${SEED}.log + cp final_model.int6.ptz checkpoints/final_model_seed${SEED}.int6.ptz +done + +# N-gram agreement evaluation (per seed) +gcc -O3 -march=native -shared -fPIC -o libonline_ngram_state.so online_ngram_state.c +for SEED in 1337 42 2025; do + BIGRAM_DIM=160 CHECKPOINT=checkpoints/final_model_seed${SEED}.int6.ptz \ + torchrun --standalone --nproc_per_node=8 eval_ngram_on_checkpoint.py +done +``` + +## Credits + +- **Training scaffold**: [PR #1179](https://github.com/openai/parameter-golf/pull/1179) by @dexhunter (built on PR #1019 by @abaybektursun) +- **N-gram agreement eval**: [PR #1145](https://github.com/openai/parameter-golf/pull/1145) by @AnirudhRahul +- **Full Hessian GPTQ**: [PR #535](https://github.com/openai/parameter-golf/pull/535) by @raahilshah +- **XSA-all**: [PR #478](https://github.com/openai/parameter-golf/pull/478) by @gowtham0992 + +## Included Files + +- `train_gpt.py` — training + quantization + sliding window eval +- `online_best_agree_eval.py` — n-gram agreement evaluation +- `online_ngram_state.c` — native n-gram hash table (compiled at eval time) +- `eval_ngram_on_checkpoint.py` — helper to run n-gram eval on saved checkpoints +- `train_seed{1337,42,2025}.log` — training logs +- `submission_ngram_seed{1337,42,2025}.log` — n-gram eval logs +- `submission.json` — leaderboard metadata diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/eval_ngram_on_checkpoint.py b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/eval_ngram_on_checkpoint.py new file mode 100644 index 0000000000..1b57fec328 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/eval_ngram_on_checkpoint.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +"""Evaluate n-gram agreement on a saved int6 checkpoint.""" +from __future__ import annotations +import io +import os +import sys +import time + +import brotli +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist + +# Add current dir to path for imports +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from train_gpt import ( + GPT, + CastedLinear, + Hyperparameters, + _byte_unshuffle, + _unbank_state_dict, + _rebank_state_dict, + build_sentencepiece_luts, + dequantize_mixed_int6, + load_validation_tokens, + restore_low_dim_params_to_fp32, +) +from online_best_agree_eval import eval_val_sliding_online_best_agree + + +def main(): + args = Hyperparameters() + args.bigram_dim = int(os.environ.get("BIGRAM_DIM", "160")) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + def log0(msg, console=True): + if master and console: + print(msg, flush=True) + + # Load tokenizer + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + + # Load int6 checkpoint + ptz_path = os.environ.get("CHECKPOINT", "final_model.int6.ptz") + log0(f"Loading checkpoint: {ptz_path}") + with open(ptz_path, "rb") as f: + quant_blob = f.read() + quant_state = torch.load( + io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob))), + map_location="cpu", + ) + + # Build model + eval_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + neg_slope=args.negative_slope, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + + # Dequantize and load weights + template_sd = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + unbanked_template = _unbank_state_dict(template_sd, args.num_layers) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_template) + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, template_sd) + eval_model.load_state_dict(deq_state, strict=True) + eval_model.eval() + + log0(f"Model loaded, running n-gram agreement eval...") + t0 = time.perf_counter() + _, best_bpb, timings = eval_val_sliding_online_best_agree( + args=args, + base_model=eval_model, + rank=rank, + world_size=world_size, + device=device, + val_tokens=val_tokens, + base_bytes_lut=base_bytes_lut, + has_leading_space_lut=has_leading_space_lut, + is_boundary_token_lut=is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=32, + eval_seq_len=args.train_seq_len, + log0=log0, + ) + elapsed = time.perf_counter() - t0 + log0(f"n-gram agreement BPB: {best_bpb:.8f} (elapsed: {elapsed:.1f}s)") + log0(f"LLM-only BPB: {timings['llm_bpb']:.8f}") + log0(f"Gain: {timings['gain_bpb']:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/online_best_agree_eval.py b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/online_best_agree_eval.py new file mode 100644 index 0000000000..dd94dbb344 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/online_best_agree_eval.py @@ -0,0 +1,671 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import ctypes +import math +import os +import subprocess +import time +from collections import deque +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +SCRIPT_DIR = Path(__file__).resolve().parent +ONLINE_NGRAM_SRC = SCRIPT_DIR / "online_ngram_state.c" +ONLINE_NGRAM_LIB = SCRIPT_DIR / "libonline_ngram_state.so" + +WHITESPACE_BYTE_IDS = {9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 36} +EDGE_PUNCT = ".,:;!?()[]{}<>\"'`" + + +def normalize_word(text: str, mode: str) -> str: + text = text.strip() + if mode == "lower": + return text.lower() + if mode == "identity": + return text + if mode == "strip_punct_lower": + return text.strip(EDGE_PUNCT).lower() + raise ValueError(f"Unknown word normalization mode: {mode}") + + +def apply_boost( + llm_true_probs: np.ndarray, + llm_hint_probs: np.ndarray, + hit_mask: np.ndarray, + gate_mask: np.ndarray, + boost: float | np.ndarray, +) -> np.ndarray: + boosted = llm_true_probs.astype(np.float64, copy=True) + if not gate_mask.any(): + return boosted + + if np.isscalar(boost): + scale = math.exp(float(boost)) + hit_gate = gate_mask & hit_mask + miss_gate = gate_mask & ~hit_mask + boosted[hit_gate] = (scale * llm_true_probs[hit_gate]) / ( + 1.0 - llm_true_probs[hit_gate] + scale * llm_true_probs[hit_gate] + ) + boosted[miss_gate] = llm_true_probs[miss_gate] / ( + 1.0 - llm_hint_probs[miss_gate] + scale * llm_hint_probs[miss_gate] + ) + return boosted + + boost_arr = boost.astype(np.float64, copy=False) + scale = np.ones(llm_true_probs.shape, dtype=np.float64) + scale[gate_mask] = np.exp(boost_arr[gate_mask]) + denom = 1.0 - llm_hint_probs + scale * llm_hint_probs + boosted = llm_true_probs / denom + hit_gate = gate_mask & hit_mask + boosted[hit_gate] *= scale[hit_gate] + return boosted + + +def expected_gain(top_prob: np.ndarray, llm_hint_prob: np.ndarray, boost: float) -> np.ndarray: + q = np.clip(llm_hint_prob.astype(np.float64, copy=False), 1e-12, 1.0 - 1e-12) + p = np.clip(top_prob.astype(np.float64, copy=False), 0.0, 1.0) + log_norm = np.log1p(q * (math.exp(boost) - 1.0)) + return (p * boost - log_norm).astype(np.float32) + + +def compute_best_agreement_chunk( + *, + llm_chunk: np.ndarray, + true_targets: np.ndarray, + token_top_prob: np.ndarray, + token_top_token: np.ndarray, + token_hint_probs: np.ndarray, + within_top_prob: np.ndarray, + within_top_token: np.ndarray, + within_valid: np.ndarray, + within_hint_probs: np.ndarray, + word_top_prob: np.ndarray, + word_top_token: np.ndarray, + word_hint_probs: np.ndarray, + token_threshold: float, + token_boost: float, + within_tau: float, + within_boost: float, + word_tau: float, + word_boost: float, + agree_add_boost: float, +) -> np.ndarray: + token_hit = token_top_token == true_targets + token_gate = token_top_prob >= token_threshold + token_exp_gain = expected_gain(token_top_prob, token_hint_probs, token_boost) + + within_hit = within_top_token == true_targets + within_gate = within_valid & (within_top_prob >= within_tau) + within_exp_gain = expected_gain(within_top_prob, within_hint_probs, within_boost) + + word_hit = word_top_token == true_targets + word_gate = word_top_prob >= word_tau + word_exp_gain = expected_gain(word_top_prob, word_hint_probs, word_boost) + + within_pick = within_gate & (~token_gate | (within_exp_gain > token_exp_gain)) + token_pick_tw = token_gate & ~within_pick + tw_gate = token_pick_tw | within_pick + + word_pick = word_gate & ((~tw_gate) | (token_pick_tw & (word_exp_gain > token_exp_gain)) | (within_pick & (word_exp_gain > within_exp_gain))) + token_pick = token_pick_tw & ~word_pick + within_pick_final = within_pick & ~word_pick + chosen_gate = token_pick | within_pick_final | word_pick + + chosen_hint_probs = np.zeros(llm_chunk.shape, dtype=np.float64) + chosen_hint_probs[token_pick] = token_hint_probs[token_pick] + chosen_hint_probs[within_pick_final] = within_hint_probs[within_pick_final] + chosen_hint_probs[word_pick] = word_hint_probs[word_pick] + + chosen_hit = np.zeros(llm_chunk.shape, dtype=np.bool_) + chosen_hit[token_pick] = token_hit[token_pick] + chosen_hit[within_pick_final] = within_hit[within_pick_final] + chosen_hit[word_pick] = word_hit[word_pick] + + chosen_boost = np.zeros(llm_chunk.shape, dtype=np.float64) + chosen_boost[token_pick] = token_boost + chosen_boost[within_pick_final] = within_boost + chosen_boost[word_pick] = word_boost + + selected_token = np.zeros(llm_chunk.shape, dtype=np.uint16) + selected_token[token_pick] = token_top_token[token_pick] + selected_token[within_pick_final] = within_top_token[within_pick_final] + selected_token[word_pick] = word_top_token[word_pick] + + agree_count = np.zeros(llm_chunk.shape, dtype=np.uint8) + agree_count += (token_gate & (token_top_token == selected_token)).astype(np.uint8) + agree_count += (within_gate & (within_top_token == selected_token)).astype(np.uint8) + agree_count += (word_gate & (word_top_token == selected_token)).astype(np.uint8) + agree_any = chosen_gate & (agree_count >= 2) + + agree_boost = chosen_boost.copy() + agree_boost[agree_any] += agree_add_boost + return apply_boost(llm_chunk, chosen_hint_probs, chosen_hit, chosen_gate, agree_boost) + + +def dist_max_float(value: float, device: torch.device, world_size: int) -> float: + if world_size <= 1: + return float(value) + tensor = torch.tensor([value], dtype=torch.float64, device=device) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return float(tensor.item()) + + +def suggest_table_bits(expected_entries: int, load_factor: float) -> int: + expected_entries = max(int(expected_entries), 1) + size = 1 + while size * load_factor < expected_entries: + size <<= 1 + return max(size.bit_length() - 1, 10) + + +def loss_to_bpb(total_loss: float, total_bytes: float) -> float: + return total_loss / (total_bytes * math.log(2.0)) + + +def loss_to_nats_per_byte(total_loss: float, total_bytes: float) -> float: + return total_loss / total_bytes + + +def build_chunk_windows(total_targets: int, seq_len: int, stride: int, chunk_tokens: int) -> list[list[int]]: + window_starts = [ + ws + for ws in range(0, total_targets, stride) + if min(ws + seq_len, total_targets) - ws >= stride or ws == 0 + ] + full_num_chunks = (total_targets + chunk_tokens - 1) // chunk_tokens + chunk_windows: list[list[int]] = [[] for _ in range(full_num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_targets) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // chunk_tokens, full_num_chunks - 1) + chunk_windows[ci].append(ws) + return chunk_windows + + +def ensure_online_ngram_lib(log0) -> ctypes.CDLL: + needs_build = (not ONLINE_NGRAM_LIB.exists()) or ( + ONLINE_NGRAM_SRC.stat().st_mtime_ns > ONLINE_NGRAM_LIB.stat().st_mtime_ns + ) + if needs_build: + log0(f"building_native_ngram_helper src={ONLINE_NGRAM_SRC.name}") + subprocess.run( + [ + "gcc", + "-O3", + "-march=native", + "-shared", + "-fPIC", + "-o", + str(ONLINE_NGRAM_LIB), + str(ONLINE_NGRAM_SRC), + ], + check=True, + ) + lib = ctypes.CDLL(str(ONLINE_NGRAM_LIB)) + lib.online_ngram_state_create.restype = ctypes.c_void_p + lib.online_ngram_state_create.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int] + lib.online_ngram_state_destroy.restype = None + lib.online_ngram_state_destroy.argtypes = [ctypes.c_void_p] + lib.online_ngram_state_seed_prefix_token.restype = None + lib.online_ngram_state_seed_prefix_token.argtypes = [ctypes.c_void_p, ctypes.c_uint16] + lib.online_ngram_state_process_chunk.restype = ctypes.c_int + lib.online_ngram_state_process_chunk.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_uint16), + ctypes.c_int64, + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint8), + ] + return lib + + +class OnlineNgramState: + def __init__( + self, + *, + lib: ctypes.CDLL, + token_ctx_len: int, + token_table_bits: int, + within_table_bits: int, + starts_new_word_lut: np.ndarray, + boundary_lut: np.ndarray, + seed_prefix_token: int, + ) -> None: + self.lib = lib + self.state = lib.online_ngram_state_create(token_ctx_len, token_table_bits, within_table_bits) + if not self.state: + raise RuntimeError( + "Failed to allocate native online ngram state. " + f"token_table_bits={token_table_bits} within_table_bits={within_table_bits}" + ) + self.starts_new_word_lut = np.ascontiguousarray(starts_new_word_lut.astype(np.uint8, copy=False)) + self.boundary_lut = np.ascontiguousarray(boundary_lut.astype(np.uint8, copy=False)) + self.lib.online_ngram_state_seed_prefix_token(self.state, ctypes.c_uint16(int(seed_prefix_token))) + + def close(self) -> None: + if self.state: + self.lib.online_ngram_state_destroy(self.state) + self.state = None + + def __del__(self) -> None: + self.close() + + def process_chunk( + self, + chunk_tokens: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + n = int(chunk_tokens.size) + token_top_token = np.zeros(n, dtype=np.uint16) + token_top_prob = np.zeros(n, dtype=np.float32) + within_top_token = np.zeros(n, dtype=np.uint16) + within_top_prob = np.zeros(n, dtype=np.float32) + within_valid = np.zeros(n, dtype=np.uint8) + rc = self.lib.online_ngram_state_process_chunk( + self.state, + chunk_tokens.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + ctypes.c_int64(n), + self.starts_new_word_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + self.boundary_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + token_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + token_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + within_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_valid.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + ) + if rc != 0: + raise RuntimeError(f"Native online ngram chunk processing failed rc={rc}") + return token_top_token, token_top_prob, within_top_token, within_top_prob, within_valid.astype(bool) + + +class WordStartState: + def __init__( + self, + *, + sp: spm.SentencePieceProcessor, + order: int, + normalize_mode: str, + ) -> None: + self.sp = sp + self.ctx_w = max(order - 1, 0) + self.normalize_mode = normalize_mode + self.prev_word_ids: deque[int] = deque(maxlen=self.ctx_w) + self.current_word_tokens: list[int] = [] + self.word_to_id: dict[str, int] = {} + self.next_word_id = 1 + self.ctx_total: dict[tuple[int, ...], int] = {} + self.pair_count: dict[tuple[tuple[int, ...], int], int] = {} + self.ctx_best_token: dict[tuple[int, ...], int] = {} + self.ctx_best_count: dict[tuple[int, ...], int] = {} + + def _flush_current_word(self) -> None: + if not self.current_word_tokens: + return + text = normalize_word( + self.sp.decode(self.current_word_tokens), + self.normalize_mode, + ) + if text: + word_id = self.word_to_id.get(text) + if word_id is None: + word_id = self.next_word_id + self.word_to_id[text] = word_id + self.next_word_id += 1 + if self.ctx_w > 0: + self.prev_word_ids.append(word_id) + self.current_word_tokens = [] + + def process_chunk( + self, + chunk_tokens: np.ndarray, + *, + starts_new_word_lut: np.ndarray, + boundary_lut: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + top_token = np.zeros(chunk_tokens.size, dtype=np.uint16) + top_prob = np.zeros(chunk_tokens.size, dtype=np.float32) + for i, tok_u16 in enumerate(chunk_tokens): + tok = int(tok_u16) + is_boundary = bool(boundary_lut[tok]) + is_word_start = bool(starts_new_word_lut[tok]) or not self.current_word_tokens + if is_boundary: + self._flush_current_word() + continue + if bool(starts_new_word_lut[tok]): + self._flush_current_word() + + ctx_key: tuple[int, ...] | None = None + if is_word_start and len(self.prev_word_ids) >= self.ctx_w: + ctx_key = tuple(self.prev_word_ids) if self.ctx_w > 0 else () + total = self.ctx_total.get(ctx_key, 0) + if total > 0: + top_token[i] = np.uint16(self.ctx_best_token[ctx_key]) + top_prob[i] = np.float32(self.ctx_best_count[ctx_key] / total) + + if is_word_start: + if ctx_key is not None: + pair_key = (ctx_key, tok) + pair = self.pair_count.get(pair_key, 0) + 1 + self.pair_count[pair_key] = pair + total = self.ctx_total.get(ctx_key, 0) + 1 + self.ctx_total[ctx_key] = total + best_count = self.ctx_best_count.get(ctx_key, 0) + if pair > best_count: + self.ctx_best_count[ctx_key] = pair + self.ctx_best_token[ctx_key] = tok + self.current_word_tokens = [tok] + else: + self.current_word_tokens.append(tok) + return top_token, top_prob + + +def build_piece_luts( + *, + tokenizer_path: str, + vocab_size: int, +) -> tuple[spm.SentencePieceProcessor, np.ndarray, np.ndarray]: + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + pieces = [sp.id_to_piece(i) for i in range(sp.vocab_size())] + starts_new_word_lut = np.zeros(vocab_size, dtype=np.uint8) + for i, piece in enumerate(pieces): + starts_new_word_lut[i] = 1 if piece.startswith("▁") else 0 + boundary_lut = np.zeros(vocab_size, dtype=np.uint8) + bos_id = sp.bos_id() + if bos_id >= 0 and bos_id < vocab_size: + boundary_lut[bos_id] = 1 + for tok in range(min(sp.vocab_size(), vocab_size)): + if sp.is_byte(tok) and tok in WHITESPACE_BYTE_IDS: + boundary_lut[tok] = 1 + return sp, starts_new_word_lut, boundary_lut + + +def compile_logits_fn(model: torch.nn.Module, *, seq_len: int, device: torch.device, log0): + if os.environ.get("EVAL_COMPILE", "0") != "1": + log0("eval-pass-online: using eager logits path") + return model.forward_logits + log0("eval-pass-online: compiling logits path") + compiled = torch.compile(model.forward_logits, dynamic=False, fullgraph=True) + dummy = torch.zeros(1, seq_len, dtype=torch.int64, device=device) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if hasattr(torch.compiler, "cudagraph_mark_step_begin"): + torch.compiler.cudagraph_mark_step_begin() + _ = compiled(dummy) + del dummy + log0("eval-pass-online: compile warmup done") + return compiled + + +def partition_windows(windows: list[int], rank: int, world_size: int) -> list[int]: + start = (len(windows) * rank) // world_size + end = (len(windows) * (rank + 1)) // world_size + return windows[start:end] + + +def eval_val_sliding_online_best_agree( + *, + args, + base_model: torch.nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: torch.Tensor, + base_bytes_lut: torch.Tensor, + has_leading_space_lut: torch.Tensor, + is_boundary_token_lut: torch.Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, + log0=print, +) -> tuple[float, float, dict[str, float]]: + startup_t0 = time.perf_counter() + seq_len = eval_seq_len or args.train_seq_len + chunk_tokens = int(os.environ.get("CHUNK_TOKENS", "131072")) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + within_tau = float(os.environ.get("WITHIN_TAU", "0.450")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.750")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "0.650")) + word_boost = float(os.environ.get("WORD_BOOST", "0.750")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.500")) + + total_targets = val_tokens.numel() - 1 + tokens_np = val_tokens.cpu().numpy().astype(np.uint16, copy=False) + sp, starts_new_word_lut, boundary_lut = build_piece_luts( + tokenizer_path=args.tokenizer_path, + vocab_size=args.vocab_size, + ) + token_table_bits = int( + os.environ.get( + "TOKEN_TABLE_BITS", + str(suggest_table_bits(total_targets, load_factor=0.55)), + ) + ) + within_table_bits = int( + os.environ.get( + "WITHIN_TABLE_BITS", + str(suggest_table_bits(max(total_targets // 2, 1), load_factor=0.60)), + ) + ) + online_lib = ensure_online_ngram_lib(log0) + ngram_state = OnlineNgramState( + lib=online_lib, + token_ctx_len=max(token_order - 1, 0), + token_table_bits=token_table_bits, + within_table_bits=within_table_bits, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + seed_prefix_token=int(tokens_np[0]), + ) + word_state = WordStartState( + sp=sp, + order=word_order, + normalize_mode=word_normalize, + ) + + compiled_logits = compile_logits_fn(base_model, seq_len=seq_len, device=device, log0=log0 if rank == 0 else (lambda *_: None)) + startup_s = time.perf_counter() - startup_t0 + startup_max_s = dist_max_float(startup_s, device, world_size) + if rank == 0: + log0( + f"online_best_agree:start total_targets={total_targets} seq_len={seq_len} stride={stride} " + f"chunk_tokens={chunk_tokens} batch_seqs={batch_seqs} token_order={token_order} " + f"word_order={word_order} startup_max={startup_max_s:.2f}s" + ) + + chunk_windows = build_chunk_windows(total_targets, seq_len, stride, chunk_tokens) + + llm_loss_sum = 0.0 + best_agree_loss_sum = 0.0 + byte_sum = 0.0 + token_count = 0.0 + state_time_s = 0.0 + input_time_s = 0.0 + forward_time_s = 0.0 + blend_time_s = 0.0 + loop_t0 = time.perf_counter() + + try: + with torch.inference_mode(): + for ci, windows in enumerate(chunk_windows): + if not windows: + continue + chunk_t0 = ci * chunk_tokens + chunk_t1 = min((ci + 1) * chunk_tokens, total_targets) + chunk_target_tokens = np.ascontiguousarray(tokens_np[chunk_t0 + 1 : chunk_t1 + 1], dtype=np.uint16) + + t_state0 = time.perf_counter() + token_top_token, token_top_prob, within_top_token, within_top_prob, within_valid = ngram_state.process_chunk( + chunk_target_tokens + ) + word_top_token, word_top_prob = word_state.process_chunk( + chunk_target_tokens, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + ) + state_time_s += time.perf_counter() - t_state0 + + my_windows = partition_windows(windows, rank, world_size) + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + if not batch_ws: + continue + t_input0 = time.perf_counter() + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + token_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + within_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + word_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + score_starts: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_targets) + wlen = end - ws + wlens.append(wlen) + local = val_tokens[ws : end + 1].to(device=device, dtype=torch.int64) + x_batch[i, :wlen] = local[:-1] + y_batch[i, :wlen] = local[1:] + s = 0 if ws == 0 else max(wlen - stride, 0) + score_starts.append(s) + if wlen - s <= 0: + continue + c0 = ws + s - chunk_t0 + c1 = ws + wlen - chunk_t0 + token_batch[i, s:wlen] = torch.from_numpy( + np.asarray(token_top_token[c0:c1], dtype=np.int64) + ).to(device=device, dtype=torch.int64) + within_batch[i, s:wlen] = torch.from_numpy( + np.asarray(within_top_token[c0:c1], dtype=np.int64) + ).to(device=device, dtype=torch.int64) + word_batch[i, s:wlen] = torch.from_numpy( + np.asarray(word_top_token[c0:c1], dtype=np.int64) + ).to(device=device, dtype=torch.int64) + input_time_s += time.perf_counter() - t_input0 + + t_forward0 = time.perf_counter() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if hasattr(torch.compiler, "cudagraph_mark_step_begin"): + torch.compiler.cudagraph_mark_step_begin() + logits = compiled_logits(x_batch) + logits_f = logits.float() + log_norm = torch.logsumexp(logits_f, dim=-1) + true_logits = logits_f.gather(-1, y_batch.unsqueeze(-1)).squeeze(-1) + token_logits = logits_f.gather(-1, token_batch.unsqueeze(-1)).squeeze(-1) + within_logits = logits_f.gather(-1, within_batch.unsqueeze(-1)).squeeze(-1) + word_logits = logits_f.gather(-1, word_batch.unsqueeze(-1)).squeeze(-1) + true_probs = (true_logits - log_norm).exp() + token_hint = (token_logits - log_norm).exp() + within_hint = (within_logits - log_norm).exp() + word_hint = (word_logits - log_norm).exp() + forward_time_s += time.perf_counter() - t_forward0 + + t_blend0 = time.perf_counter() + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = score_starts[i] + if wlen - s <= 0: + continue + c0 = ws + s - chunk_t0 + c1 = ws + wlen - chunk_t0 + llm_chunk = true_probs[i, s:wlen].detach().cpu().numpy().astype(np.float64, copy=False) + token_hint_chunk = token_hint[i, s:wlen].detach().cpu().numpy().astype(np.float32, copy=False) + within_hint_chunk = within_hint[i, s:wlen].detach().cpu().numpy().astype(np.float32, copy=False) + word_hint_chunk = word_hint[i, s:wlen].detach().cpu().numpy().astype(np.float32, copy=False) + best_agree_chunk = compute_best_agreement_chunk( + llm_chunk=llm_chunk, + true_targets=chunk_target_tokens[c0:c1], + token_top_prob=np.asarray(token_top_prob[c0:c1], dtype=np.float32), + token_top_token=np.asarray(token_top_token[c0:c1], dtype=np.uint16), + token_hint_probs=token_hint_chunk, + within_top_prob=np.asarray(within_top_prob[c0:c1], dtype=np.float32), + within_top_token=np.asarray(within_top_token[c0:c1], dtype=np.uint16), + within_valid=np.asarray(within_valid[c0:c1], dtype=np.bool_), + within_hint_probs=within_hint_chunk, + word_top_prob=np.asarray(word_top_prob[c0:c1], dtype=np.float32), + word_top_token=np.asarray(word_top_token[c0:c1], dtype=np.uint16), + word_hint_probs=word_hint_chunk, + token_threshold=token_threshold, + token_boost=token_boost, + within_tau=within_tau, + within_boost=within_boost, + word_tau=word_tau, + word_boost=word_boost, + agree_add_boost=agree_add_boost, + ) + llm_loss_sum += float((-np.log(np.clip(llm_chunk, 1e-12, 1.0))).sum()) + best_agree_loss_sum += float((-np.log(np.clip(best_agree_chunk, 1e-12, 1.0))).sum()) + token_count += float(c1 - c0) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_sum += float(tb.sum().item()) + blend_time_s += time.perf_counter() - t_blend0 + finally: + ngram_state.close() + + llm_loss_t = torch.tensor([llm_loss_sum], dtype=torch.float64, device=device) + best_loss_t = torch.tensor([best_agree_loss_sum], dtype=torch.float64, device=device) + byte_sum_t = torch.tensor([byte_sum], dtype=torch.float64, device=device) + token_count_t = torch.tensor([token_count], dtype=torch.float64, device=device) + if world_size > 1: + dist.all_reduce(llm_loss_t, op=dist.ReduceOp.SUM) + dist.all_reduce(best_loss_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + + state_max_s = dist_max_float(state_time_s, device, world_size) + input_max_s = dist_max_float(input_time_s, device, world_size) + forward_max_s = dist_max_float(forward_time_s, device, world_size) + blend_max_s = dist_max_float(blend_time_s, device, world_size) + loop_total_max_s = dist_max_float(time.perf_counter() - loop_t0, device, world_size) + + llm_total_loss = float(llm_loss_t.item()) + best_total_loss = float(best_loss_t.item()) + total_bytes = float(byte_sum_t.item()) + total_token_count = float(token_count_t.item()) + llm_bpb = loss_to_bpb(llm_total_loss, total_bytes) + best_agree_bpb = loss_to_bpb(best_total_loss, total_bytes) + + timings = { + "llm_bpb": llm_bpb, + "best_agree_bpb": best_agree_bpb, + "gain_bpb": llm_bpb - best_agree_bpb, + "startup_max_s": startup_max_s, + "loop_total_max_s": loop_total_max_s, + "state_max_s": state_max_s, + "input_max_s": input_max_s, + "forward_max_s": forward_max_s, + "blend_max_s": blend_max_s, + "llm_nats_per_byte": loss_to_nats_per_byte(llm_total_loss, total_bytes), + "best_agree_nats_per_byte": loss_to_nats_per_byte(best_total_loss, total_bytes), + "gain_nats_per_byte": loss_to_nats_per_byte(llm_total_loss, total_bytes) + - loss_to_nats_per_byte(best_total_loss, total_bytes), + } + if rank == 0: + log0( + f"online_best_agree:done llm_bpb={llm_bpb:.8f} best_agree_bpb={best_agree_bpb:.8f} " + f"gain_bpb={llm_bpb - best_agree_bpb:.8f} startup_max={startup_max_s:.2f}s " + f"loop_total_max={loop_total_max_s:.2f}s state_max={state_max_s:.2f}s " + f"input_max={input_max_s:.2f}s forward_max={forward_max_s:.2f}s blend_max={blend_max_s:.2f}s" + ) + return best_total_loss / max(total_token_count, 1.0), best_agree_bpb, timings diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/online_ngram_state.c b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/online_ngram_state.c new file mode 100644 index 0000000000..f8472a6f05 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/online_ngram_state.c @@ -0,0 +1,433 @@ +#include +#include +#include + +#define COEFF_COUNT 32 + +static const uint64_t ROLLING_COEFFS[COEFF_COUNT] = { + 36313ULL, 27191ULL, 51647ULL, 81929ULL, 131071ULL, 196613ULL, + 262147ULL, 393241ULL, 524309ULL, 655373ULL, 786433ULL, 917521ULL, + 1048583ULL, 1179653ULL, 1310729ULL, 1441801ULL, 1572869ULL, 1703941ULL, + 1835017ULL, 1966087ULL, 2097169ULL, 2228243ULL, 2359319ULL, 2490389ULL, + 2621471ULL, 2752549ULL, 2883617ULL, 3014687ULL, 3145757ULL, 3276833ULL, + 3407903ULL, 3538973ULL, +}; + +static const uint64_t PAIR_MIX = 1000003ULL; +static const uint64_t PREFIX_BASE = 1099511628211ULL; +static const uint64_t LEN_MIX = 0x9E3779B185EBCA87ULL; +static const uint64_t TABLE_MIX = 0x9e3779b97f4a7c15ULL; + +typedef struct { + uint64_t key; + uint32_t total; + uint32_t top_count; + uint16_t top_tok; + uint16_t _pad; +} CtxBucket; + +typedef struct { + uint64_t key; + uint32_t count; + uint32_t _pad; +} PairBucket; + +typedef struct { + int token_ctx_len; + int token_prefix_len; + int token_head; + uint16_t *token_ring; + + CtxBucket *token_ctx_tbl; + uint8_t *token_ctx_used; + size_t token_ctx_mask; + + PairBucket *token_pair_tbl; + uint8_t *token_pair_used; + size_t token_pair_mask; + + uint64_t within_hash; + uint32_t within_len; + + CtxBucket *within_ctx_tbl; + uint8_t *within_ctx_used; + size_t within_ctx_mask; + + PairBucket *within_pair_tbl; + uint8_t *within_pair_used; + size_t within_pair_mask; +} OnlineNgramState; + +static inline size_t mix_index(uint64_t key, size_t mask) { + return (size_t)((key * TABLE_MIX) & mask); +} + +static inline size_t find_ctx_slot( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline size_t find_pair_slot( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline uint64_t token_pair_key(uint64_t ctx_key, uint16_t tok, int ctx_len) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[(size_t)ctx_len % COEFF_COUNT]); +} + +static inline uint64_t within_pair_key(uint64_t ctx_key, uint16_t tok) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[0]); +} + +static inline uint64_t extend_prefix_hash(uint64_t current_hash, uint16_t tok, uint32_t pos) { + return (current_hash * PREFIX_BASE) ^ (((uint64_t)tok + 1ULL) * ROLLING_COEFFS[(size_t)pos % COEFF_COUNT]); +} + +static inline uint32_t pair_increment( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key +) { + int found = 0; + size_t idx = find_pair_slot(tbl, used, mask, key, &found); + if (found < 0) { + return 0U; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].count = 1U; + return 1U; + } + tbl[idx].count += 1U; + return tbl[idx].count; +} + +static inline int ctx_increment( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + uint16_t tok, + uint32_t pair_count +) { + int found = 0; + size_t idx = find_ctx_slot(tbl, used, mask, key, &found); + if (found < 0) { + return -1; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].total = 1U; + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + return 0; + } + tbl[idx].total += 1U; + if (pair_count > tbl[idx].top_count) { + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + } + return 0; +} + +static inline uint64_t token_context_hash(const OnlineNgramState *st) { + uint64_t h = 0ULL; + if (st->token_ctx_len <= 0) { + return h; + } + for (int j = 0; j < st->token_ctx_len; ++j) { + const int ring_idx = (st->token_head + j) % st->token_ctx_len; + h ^= ((uint64_t)st->token_ring[ring_idx]) * ROLLING_COEFFS[(size_t)j]; + } + return h; +} + +static inline void token_push(OnlineNgramState *st, uint16_t tok) { + if (st->token_ctx_len <= 0) { + return; + } + if (st->token_prefix_len < st->token_ctx_len) { + st->token_ring[st->token_prefix_len] = tok; + st->token_prefix_len += 1; + return; + } + st->token_ring[st->token_head] = tok; + st->token_head = (st->token_head + 1) % st->token_ctx_len; +} + +static void *xcalloc(size_t count, size_t size) { + if (count == 0 || size == 0) { + return NULL; + } + return calloc(count, size); +} + +static int alloc_tables( + size_t table_bits, + CtxBucket **ctx_tbl, + uint8_t **ctx_used, + size_t *ctx_mask, + PairBucket **pair_tbl, + uint8_t **pair_used, + size_t *pair_mask +) { + const size_t size = 1ULL << table_bits; + *ctx_tbl = (CtxBucket *)xcalloc(size, sizeof(CtxBucket)); + *ctx_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + *pair_tbl = (PairBucket *)xcalloc(size, sizeof(PairBucket)); + *pair_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + if (!*ctx_tbl || !*ctx_used || !*pair_tbl || !*pair_used) { + return -1; + } + *ctx_mask = size - 1U; + *pair_mask = size - 1U; + return 0; +} + +void *online_ngram_state_create( + int token_ctx_len, + int token_table_bits, + int within_table_bits +) { + if (token_ctx_len < 0 || token_table_bits <= 0 || within_table_bits <= 0) { + return NULL; + } + OnlineNgramState *st = (OnlineNgramState *)calloc(1, sizeof(OnlineNgramState)); + if (!st) { + return NULL; + } + st->token_ctx_len = token_ctx_len; + if (token_ctx_len > 0) { + st->token_ring = (uint16_t *)xcalloc((size_t)token_ctx_len, sizeof(uint16_t)); + if (!st->token_ring) { + free(st); + return NULL; + } + } + if (alloc_tables( + (size_t)token_table_bits, + &st->token_ctx_tbl, + &st->token_ctx_used, + &st->token_ctx_mask, + &st->token_pair_tbl, + &st->token_pair_used, + &st->token_pair_mask + ) != 0) { + free(st->token_ring); + free(st); + return NULL; + } + if (alloc_tables( + (size_t)within_table_bits, + &st->within_ctx_tbl, + &st->within_ctx_used, + &st->within_ctx_mask, + &st->within_pair_tbl, + &st->within_pair_used, + &st->within_pair_mask + ) != 0) { + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); + return NULL; + } + return (void *)st; +} + +void online_ngram_state_destroy(void *ptr) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + free(st->within_pair_used); + free(st->within_pair_tbl); + free(st->within_ctx_used); + free(st->within_ctx_tbl); + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); +} + +void online_ngram_state_seed_prefix_token(void *ptr, uint16_t tok) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + token_push(st, tok); +} + +int online_ngram_state_process_chunk( + void *ptr, + const uint16_t *tokens, + int64_t n_tokens, + const uint8_t *starts_new_word_lut, + const uint8_t *boundary_lut, + uint16_t *token_top_token, + float *token_top_prob, + uint16_t *within_top_token, + float *within_top_prob, + uint8_t *within_valid +) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st || !tokens || n_tokens < 0) { + return -1; + } + for (int64_t i = 0; i < n_tokens; ++i) { + const uint16_t tok = tokens[i]; + const uint8_t is_boundary = boundary_lut[tok]; + const uint8_t is_new_word = starts_new_word_lut[tok]; + + uint64_t token_ctx_key = 0ULL; + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + token_ctx_key = token_context_hash(st); + int found = 0; + size_t idx = find_ctx_slot( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + &found + ); + if (found > 0) { + token_top_token[i] = st->token_ctx_tbl[idx].top_tok; + token_top_prob[i] = + (float)st->token_ctx_tbl[idx].top_count / (float)st->token_ctx_tbl[idx].total; + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + + uint64_t within_ctx_key = 0ULL; + if (!is_boundary && !is_new_word && st->within_len > 0U) { + within_ctx_key = st->within_hash ^ ((uint64_t)st->within_len * LEN_MIX); + int found = 0; + size_t idx = find_ctx_slot( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + &found + ); + within_valid[i] = 1U; + if (found > 0) { + within_top_token[i] = st->within_ctx_tbl[idx].top_tok; + within_top_prob[i] = + (float)st->within_ctx_tbl[idx].top_count / (float)st->within_ctx_tbl[idx].total; + } else { + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + } else { + within_valid[i] = 0U; + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + const uint64_t pair_key = token_pair_key(token_ctx_key, tok, st->token_ctx_len); + const uint32_t pair_count = pair_increment( + st->token_pair_tbl, + st->token_pair_used, + st->token_pair_mask, + pair_key + ); + if (pair_count == 0U) { + return -2; + } + if (ctx_increment( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + tok, + pair_count + ) != 0) { + return -3; + } + } + token_push(st, tok); + + if (is_boundary) { + st->within_hash = 0ULL; + st->within_len = 0U; + continue; + } + if (is_new_word || st->within_len == 0U) { + st->within_hash = extend_prefix_hash(0ULL, tok, 0U); + st->within_len = 1U; + continue; + } + const uint32_t within_pair_count = pair_increment( + st->within_pair_tbl, + st->within_pair_used, + st->within_pair_mask, + within_pair_key(within_ctx_key, tok) + ); + if (within_pair_count == 0U) { + return -4; + } + if (ctx_increment( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + tok, + within_pair_count + ) != 0) { + return -5; + } + st->within_hash = extend_prefix_hash(st->within_hash, tok, st->within_len); + st->within_len += 1U; + } + return 0; +} diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/requirements.txt b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/requirements.txt new file mode 100644 index 0000000000..473b97d5da --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/requirements.txt @@ -0,0 +1,5 @@ +torch>=2.9.0 +numpy +sentencepiece +brotli +flash_attn_3 diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission.json b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission.json new file mode 100644 index 0000000000..81526a33a8 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission.json @@ -0,0 +1,19 @@ +{ + "author": "vlivashkin", + "github_id": "vlivashkin", + "date": "2026-04-03", + "val_bpb": 1.1078, + "val_loss": 1.87521, + "bytes_total": 15857705, + "seeds": [1337, 42, 2025], + "seed_results": { + "1337": {"val_bpb": 1.1083, "sliding_bpb": 1.1110, "val_loss": 1.87595, "artifact_bytes": 15853466}, + "42": {"val_bpb": 1.1068, "sliding_bpb": 1.1095, "val_loss": 1.87342, "artifact_bytes": 15857705}, + "2025": {"val_bpb": 1.1085, "sliding_bpb": 1.1112, "val_loss": 1.87627, "artifact_bytes": 15846914} + }, + "val_bpb_std": 0.0009, + "delta_nats_vs_sota": 0.00697, + "sota_pr": 1019, + "hardware": "8xH100 SXM 80GB", + "blurb": "Split-LR + BigramHash(2816x160) + Full Hessian GPTQ + Soft-round QAT + Brotli + Online N-gram Agreement Eval" +} diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed1337.log b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed1337.log new file mode 100644 index 0000000000..53eb79e015 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed1337.log @@ -0,0 +1,12 @@ +W0403 12:21:02.729000 89872 torch/distributed/run.py:803] +W0403 12:21:02.729000 89872 torch/distributed/run.py:803] ***************************************** +W0403 12:21:02.729000 89872 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0403 12:21:02.729000 89872 torch/distributed/run.py:803] ***************************************** +Loading checkpoint: checkpoints/final_model_seed1337.int6.ptz +Model loaded, running n-gram agreement eval... +eval-pass-online: using eager logits path +online_best_agree:start total_targets=62021632 seq_len=2048 stride=64 chunk_tokens=131072 batch_seqs=32 token_order=16 word_order=4 startup_max=0.00s +online_best_agree:done llm_bpb=1.11133698 best_agree_bpb=1.10826554 gain_bpb=0.00307144 startup_max=0.00s loop_total_max=444.72s state_max=201.89s input_max=13.04s forward_max=42.71s blend_max=189.87s +n-gram agreement BPB: 1.10826554 (elapsed: 448.7s) +LLM-only BPB: 1.11133698 +Gain: 0.00307144 diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed2025.log b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed2025.log new file mode 100644 index 0000000000..f40cc09959 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed2025.log @@ -0,0 +1,12 @@ +W0403 13:03:58.864000 92171 torch/distributed/run.py:803] +W0403 13:03:58.864000 92171 torch/distributed/run.py:803] ***************************************** +W0403 13:03:58.864000 92171 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0403 13:03:58.864000 92171 torch/distributed/run.py:803] ***************************************** +Loading checkpoint: checkpoints/final_model_seed2025.int6.ptz +Model loaded, running n-gram agreement eval... +eval-pass-online: using eager logits path +online_best_agree:start total_targets=62021632 seq_len=2048 stride=64 chunk_tokens=131072 batch_seqs=32 token_order=16 word_order=4 startup_max=0.00s +online_best_agree:done llm_bpb=1.11151742 best_agree_bpb=1.10846117 gain_bpb=0.00305625 startup_max=0.00s loop_total_max=450.57s state_max=206.89s input_max=13.00s forward_max=43.63s blend_max=189.57s +n-gram agreement BPB: 1.10846117 (elapsed: 454.3s) +LLM-only BPB: 1.11151742 +Gain: 0.00305625 diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed42.log b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed42.log new file mode 100644 index 0000000000..9ea23ef5c9 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/submission_ngram_seed42.log @@ -0,0 +1,12 @@ +W0403 12:42:29.948000 91028 torch/distributed/run.py:803] +W0403 12:42:29.948000 91028 torch/distributed/run.py:803] ***************************************** +W0403 12:42:29.948000 91028 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0403 12:42:29.948000 91028 torch/distributed/run.py:803] ***************************************** +Loading checkpoint: checkpoints/final_model_seed42.int6.ptz +Model loaded, running n-gram agreement eval... +eval-pass-online: using eager logits path +online_best_agree:start total_targets=62021632 seq_len=2048 stride=64 chunk_tokens=131072 batch_seqs=32 token_order=16 word_order=4 startup_max=0.00s +online_best_agree:done llm_bpb=1.10984470 best_agree_bpb=1.10680980 gain_bpb=0.00303490 startup_max=0.00s loop_total_max=444.66s state_max=200.01s input_max=13.14s forward_max=42.63s blend_max=189.63s +n-gram agreement BPB: 1.10680980 (elapsed: 448.4s) +LLM-only BPB: 1.10984470 +Gain: 0.00303490 diff --git a/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/train_gpt.py b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/train_gpt.py new file mode 100644 index 0000000000..9ec8f14e11 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SplitLR_NgramAgreement_FullGPTQ/train_gpt.py @@ -0,0 +1,702 @@ +from __future__ import annotations +_i='passthrough_ctrl' +_h='passthrough_orig_dtypes' +_g='dtypes' +_f='scales' +_e='quantized' +_d='per_row' +_c='scheme' +_b='torch.' +_a='momentum' +_Z='shard_mom' +_Y='padded_grad' +_X='fineweb_train_*.bin' +_W='little' +_V='.scale' +_U='mlp_down_bank' +_T='mlp_up_bank' +_S='kv_bank' +_R='qo_bank' +_Q='X.size(-1) + if transposed:X=X.mT + X=X/(X.norm(dim=(-2,-1),keepdim=_B)+eps) + for _ in range(steps):A=X@X.mT;B=b*A+c*(A@A);X=a*X+B@X + if transposed:X=X.mT + if was_2d:X=X.squeeze(0) + return X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_E):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay));self._built=_C + def _build(self): + self._distributed=dist.is_available()and dist.is_initialized();self._world_size=dist.get_world_size()if self._distributed else 1;self._rank=dist.get_rank()if self._distributed else 0;ws=self._world_size;self._bank_meta=[] + for group in self.param_groups: + for p in group[_G]:B=p.shape[0];padded_B=(B+ws-1)//ws*ws;shard_B=padded_B//ws;tail=p.shape[1:];dev=p.device;self._bank_meta.append({'p':p,'B':B,_Y:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_O:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_Z:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_J:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_K:max(1,p.shape[-2]/p.shape[-1])**.5}) + self._bank_meta.sort(key=lambda m:-m['p'].numel());self._built=_B + def launch_reduce_scatters(self): + if not self._built:self._build() + if not self._distributed:return + self._rs_futures=[] + for m in self._bank_meta: + p=m['p'] + if p.grad is _A:self._rs_futures.append(_A);continue + pg=m[_Y];pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0]>m['B']:pg[m['B']:].zero_() + fut=dist.reduce_scatter_tensor(m[_O],pg,op=dist.ReduceOp.AVG,async_op=_B);self._rs_futures.append(fut) + @torch.no_grad() + def step(self,closure=_A): + B='_rs_futures';A='momentum_buffer';loss=_A + if closure is not _A: + with torch.enable_grad():loss=closure() + if not self._built:self._build() + for group in self.param_groups: + lr=group[_H];momentum=group[_a];backend_steps=group['backend_steps'];nesterov=group['nesterov'];wd=group.get('weight_decay',_E);prev_ag_handle=_A;prev_m=_A;sharded=self._distributed and hasattr(self,B) + for(i,m)in enumerate(self._bank_meta): + p=m['p'] + if p.grad is _A:continue + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if sharded and self._rs_futures[i]is not _A:self._rs_futures[i].wait();g=m[_O];buf=m[_Z] + else: + g=p.grad.bfloat16();state=self.state[p] + if A not in state:state[A]=torch.zeros_like(g) + buf=state[A] + buf.mul_(momentum).add_(g) + if nesterov:update=g.add(buf,alpha=momentum) + else:update=buf + update=zeropower_via_newtonschulz5(update,steps=backend_steps) + if sharded:prev_ag_handle=dist.all_gather_into_tensor(m[_J],update,async_op=_B);prev_m=m + else: + if wd>_E:p.data.mul_(_D-lr*wd) + p.add_(update.to(dtype=p.dtype),alpha=-lr*m[_K]) + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if hasattr(self,B):del self._rs_futures + return loss +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=_C + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if piece.startswith('▁'):has_leading_space_np[token_id]=_B;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode(_I)) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[:usable+1] +def eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;local_batch_tokens=args.val_batch_size//(world_size*grad_accum_steps) + if local_batch_tokens0 else _D,dtype=torch.float32);q=torch.clamp(torch.round(torch.clamp(t32,-clip_abs,clip_abs)/scale),-127,127).to(torch.int8).contiguous();return q,scale +def quantize_state_dict_int8(state_dict): + F='baseline_tensor_bytes';E='num_nonfloat_tensors';D='num_float_tensors';C='num_tensors';B='param_count';A='int8_payload_bytes';quantized={};scales={};dtypes={};passthrough={};passthrough_orig_dtypes={};qmeta={};stats=dict.fromkeys((B,C,D,E,F,A),0) + for(name,tensor)in state_dict.items(): + t=tensor.detach().to(_P).contiguous();stats[B]+=int(t.numel());stats[C]+=1;stats[F]+=tensor_nbytes(t) + if not t.is_floating_point():stats[E]+=1;passthrough[name]=t;stats[A]+=tensor_nbytes(t);continue + if t.numel()<=INT8_KEEP_FLOAT_MAX_NUMEL:kept=keep_float_tensor(name,t,passthrough_orig_dtypes);passthrough[name]=kept;stats[A]+=tensor_nbytes(kept);continue + stats[D]+=1;q,s=quantize_float_tensor(t) + if s.ndim>0:qmeta[name]={_c:_d,'axis':0} + quantized[name]=q;scales[name]=s;dtypes[name]=str(t.dtype).removeprefix(_b);stats[A]+=tensor_nbytes(q)+tensor_nbytes(s) + obj={'__quant_format__':'int8_clean_per_row_v1',_e:quantized,_f:scales,_g:dtypes,_L:passthrough} + if qmeta:obj['qmeta']=qmeta + if passthrough_orig_dtypes:obj[_h]=passthrough_orig_dtypes + return obj,stats +def dequantize_state_dict_int8(obj): + out={};qmeta=obj.get('qmeta',{});passthrough_orig_dtypes=obj.get(_h,{}) + for(name,q)in obj[_e].items(): + dtype=getattr(torch,obj[_g][name]);s=obj[_f][name] + if qmeta.get(name,{}).get(_c)==_d or s.ndim>0:s=s.to(dtype=torch.float32);out[name]=(q.float()*s.view(q.shape[0],*[1]*(q.ndim-1))).to(dtype=dtype).contiguous() + else:scale=float(s.item());out[name]=(q.float()*scale).to(dtype=dtype).contiguous() + for(name,t)in obj[_L].items(): + out_t=t.detach().to(_P).contiguous();orig_dtype=passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype,str):out_t=out_t.to(dtype=getattr(torch,orig_dtype)).contiguous() + out[name]=out_t + return out +def load_data_shard(file): + header_bytes=256*np.dtype(_M).itemsize;token_bytes=np.dtype(_Q).itemsize;header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + num_tokens=int(header[2]);expected_size=header_bytes+num_tokens*token_bytes + if file.stat().st_size!=expected_size:raise ValueError(f"Shard size mismatch for {file}: expected {expected_size} bytes") + tokens_np=np.fromfile(file,dtype=_Q,count=num_tokens,offset=header_bytes) + if tokens_np.size!=num_tokens:raise ValueError(f"Short read for {file}") + return torch.from_numpy(tokens_np.astype(np.uint16,copy=_C)) +_SHARD_HEADER_BYTES=256*np.dtype(_M).itemsize +_SHARD_NTOKENS_CACHE={} +_MMAP_CACHE={} +def _read_num_tokens(file): + key=str(file);cached=_SHARD_NTOKENS_CACHE.get(key) + if cached is not _A:return cached + header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + n=int(header[2]);_SHARD_NTOKENS_CACHE[key]=n;return n +def _get_shard_memmap(file): + key=str(file);mm=_MMAP_CACHE.get(key) + if mm is not _A:return mm + n=_read_num_tokens(file);mm=np.memmap(file,mode='r',dtype=_Q,offset=_SHARD_HEADER_BYTES,shape=(n,));_MMAP_CACHE[key]=mm;return mm +class DistributedTokenLoader: + def __init__(self,pattern,rank,world_size,device): + self.rank=rank;self.world_size=world_size;self.device=device;self.files=[Path(p)for p in sorted(glob.glob(pattern))] + if not self.files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + self._num_tokens=np.array([_read_num_tokens(f)for f in self.files],dtype=np.int64);seed=0 + for f in self.files: + for b in str(f).encode():seed=(seed^b)*1099511628211&0xffffffffffffffff + self._rng=np.random.Generator(np.random.PCG64(seed));self._cfg=_A;self._eligible_shards=_A;self._base_block_counts=_A;n=len(self.files);self._cursor_phase=np.zeros(n,dtype=np.int64);self._cursor_block_count=np.zeros(n,dtype=np.int64);self._cursor_next=np.zeros(n,dtype=np.int64);self._cursor_start=np.zeros(n,dtype=np.int64);self._cursor_stride=np.ones(n,dtype=np.int64);self._cursor_init=np.zeros(n,dtype=np.bool_);self._batches_built=0 + def _pick_coprime_stride(self,n): + if n<=1:return 1 + while _B: + s=int(self._rng.integers(1,n)) + if math.gcd(s,n)==1:return s + def _reset_cursor(self,si,seq_len):nt=int(self._num_tokens[si]);max_phase=min(seq_len-1,max(0,nt-seq_len-1));phase=int(self._rng.integers(max_phase+1))if max_phase>0 else 0;bc=(nt-1-phase)//seq_len;self._cursor_phase[si]=phase;self._cursor_block_count[si]=bc;self._cursor_next[si]=0;self._cursor_start[si]=int(self._rng.integers(bc))if bc>1 else 0;self._cursor_stride[si]=self._pick_coprime_stride(bc);self._cursor_init[si]=_B + def _ensure_cursor(self,si,seq_len): + if not self._cursor_init[si]or self._cursor_next[si]>=self._cursor_block_count[si]:self._reset_cursor(si,seq_len) + def _take_from_shard(self,si,seq_len,count,out): + rem=count + while rem>0: + self._ensure_cursor(si,seq_len);bc=int(self._cursor_block_count[si]);ni=int(self._cursor_next[si]);take=min(rem,bc-ni);phase=int(self._cursor_phase[si]);start=int(self._cursor_start[si]);stride=int(self._cursor_stride[si]) + for j in range(take):bi=(start+(ni+j)*stride)%bc;out.append((si,phase+bi*seq_len)) + self._cursor_next[si]=ni+take;rem-=take + def _init_pipeline(self,global_tokens,seq_len,grad_accum_steps):local_tokens=global_tokens//(self.world_size*grad_accum_steps);num_seqs=local_tokens//seq_len;global_num_seqs=num_seqs*self.world_size;self._cfg=local_tokens,seq_len,num_seqs,global_num_seqs;bbc=(self._num_tokens-1)//seq_len;eligible=bbc>0;self._eligible_shards=np.nonzero(eligible)[0].astype(np.int64);self._base_block_counts=bbc[self._eligible_shards].astype(np.int64) + def _sample_global_windows(self): + _,seq_len,_,gns=self._cfg;ec=int(self._eligible_shards.size);progress=min(self._batches_built/18e2,_D);remaining=np.empty(ec,dtype=np.float64) + for(i,si)in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]:r=int(self._cursor_block_count[si])-int(self._cursor_next[si]);remaining[i]=float(max(r,1)) + else:remaining[i]=float(self._base_block_counts[i]) + alpha=.9-.4*progress;weights=np.power(remaining,alpha);ws=float(weights.sum()) + if not np.isfinite(ws)or ws<=_E:weights=np.ones(ec,dtype=np.float64);ws=float(weights.sum()) + probs=weights/ws;low=min(max(8,self.world_size),ec,gns);high=min(max(32,self.world_size*8),ec,gns);mix=max(1,min(int(round(low+progress*(high-low))),ec,gns));cp=self._rng.choice(ec,size=mix,replace=_C,p=probs);cs=self._eligible_shards[cp];cpr=probs[cp].copy();cpr/=cpr.sum();counts=np.ones(mix,dtype=np.int64);extra=gns-mix + if extra>0:counts+=self._rng.multinomial(extra,cpr).astype(np.int64) + perm=self._rng.permutation(mix);cs,counts=cs[perm],counts[perm];buckets=[] + for(si,cnt)in zip(cs.tolist(),counts.tolist()): + b=[];self._take_from_shard(int(si),seq_len,int(cnt),b) + if b: + if len(b)>1:bp=self._rng.permutation(len(b));b=[b[int(k)]for k in bp.tolist()] + buckets.append(b) + windows=[];active=[i for(i,bk)in enumerate(buckets)if bk] + while active: + order=self._rng.permutation(len(active));new_active=[] + for oi in order.tolist(): + bi=active[oi] + if buckets[bi]:windows.append(buckets[bi].pop()) + if buckets[bi]:new_active.append(bi) + active=new_active + return windows + def next_batch(self,global_tokens,seq_len,grad_accum_steps): + if self._cfg is _A:self._init_pipeline(global_tokens,seq_len,grad_accum_steps) + _,_,num_seqs,gns=self._cfg;gw=self._sample_global_windows();local_w=gw[self.rank::self.world_size];x=torch.empty((num_seqs,seq_len),dtype=torch.int64);y=torch.empty((num_seqs,seq_len),dtype=torch.int64) + for(slot,(si,pos))in enumerate(local_w):mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[pos:pos+seq_len+1],dtype=np.int64));x[slot]=window[:-1];y[slot]=window[1:] + self._batches_built+=1;return x.to(self.device,non_blocking=_B),y.to(self.device,non_blocking=_B) +class RMSNorm(nn.Module): + def __init__(self,eps=_A):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled=_C;_qat_alpha=_D + def forward(self,x): + w=self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim==2:w32=self.weight.float();row_max=w32.abs().amax(dim=1);s=(row_max/31.).clamp_min(_D/31.);scaled=w32/s[:,_A];alpha=CastedLinear._qat_alpha;frac=scaled-scaled.floor();soft_rounded=scaled.floor()+torch.sigmoid(alpha*(frac-.5));w_q=(torch.clamp(soft_rounded,-31,31)*s[:,_A]).to(x.dtype);w=w_q + bias=self.bias.to(x.dtype)if self.bias is not _A else _A;return F.linear(x,w,bias) +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for(name,param)in module.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=_D/base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=_C);self._seq_len_cached=0;self._cos_cached=_A;self._sin_cached=_A + def forward(self,seq_len,device,dtype): + if self._cos_cached is _A or self._sin_cached is _A or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=_D/new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[_A,:,_A,:];self._sin_cached=freqs.sin()[_A,:,_A,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0 else _A;self.smear=SmearGate(model_dim);self.num_encoder_layers=num_layers//2;self.num_decoder_layers=num_layers-self.num_encoder_layers;self.num_skip_weights=min(self.num_encoder_layers,self.num_decoder_layers);self.skip_weights=nn.Parameter(torch.ones(self.num_skip_weights,model_dim,dtype=torch.float32));self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,model_dim,dtype=torch.float32));head_dim=model_dim//num_heads;kv_dim=num_kv_heads*head_dim;mlp_dim=int(mlp_mult*model_dim);self.num_layers=num_layers;self.qo_bank=nn.Parameter(torch.empty(2*num_layers,model_dim,model_dim));self.kv_bank=nn.Parameter(torch.empty(2*num_layers,kv_dim,model_dim));self.mlp_up_bank=nn.Parameter(torch.empty(num_layers,mlp_dim,model_dim));self.mlp_down_bank=nn.Parameter(torch.empty(num_layers,model_dim,mlp_dim));self.blocks=nn.ModuleList([Block(model_dim,num_heads,num_kv_heads,mlp_mult,rope_base,qk_gain_init,layer_idx=i,ln_scale=ln_scale,neg_slope=neg_slope)for i in range(num_layers)]) + if rope_dims>0: + head_dim=model_dim//num_heads + for block in self.blocks:block.attn.rope_dims=rope_dims;block.attn.rotary=Rotary(head_dim,base=rope_base,train_seq_len=1024,rope_dims=rope_dims) + self.ve_layer_indices=[int(x)for x in ve_layers.split(',')if x.strip()]if ve_enabled else[];kv_dim_ve=self._ve_target_dim + if self.ve_layer_indices:self.ve_shared=ValueEmbedding(vocab_size,ve_dim,kv_dim_ve);self.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for _ in self.ve_layer_indices]) + else:self.ve_shared=_A;self.ve_layer_scales=nn.ParameterList() + self.value_embeds=nn.ModuleList();self.final_norm=RMSNorm();self.lm_head=_A if tie_embeddings else CastedLinear(model_dim,vocab_size,bias=_C) + if self.lm_head is not _A:self.lm_head._zero_init=_B + if xsa_last_n>0: + for i in range(max(0,num_layers-xsa_last_n),num_layers):self.blocks[i].attn.use_xsa=_B + self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=_E,std=self.tied_embed_init_std) + n=self.num_layers;proj_scale=_D/math.sqrt(2*n) + for i in range(n):nn.init.orthogonal_(self.qo_bank.data[i],gain=_D);nn.init.zeros_(self.qo_bank.data[n+i]);nn.init.orthogonal_(self.kv_bank.data[i],gain=_D);nn.init.orthogonal_(self.kv_bank.data[n+i],gain=_D);nn.init.orthogonal_(self.mlp_up_bank.data[i],gain=_D);nn.init.zeros_(self.mlp_down_bank.data[i]);self.qo_bank.data[n+i].mul_(proj_scale);self.mlp_down_bank.data[i].mul_(proj_scale) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if getattr(module,'_zero_init',_C):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=_D) + def _get_ve(self,layer_idx,input_ids,ve_cache=_A): + A='ve' + if self.ve_shared is _A or layer_idx not in self.ve_layer_indices:return + if ve_cache is not _A and A not in ve_cache:ve_cache[A]=self.ve_shared(input_ids) + ve_base=ve_cache[A]if ve_cache is not _A else self.ve_shared(input_ids);ve_idx=self.ve_layer_indices.index(layer_idx);return ve_base*self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self,input_ids,target_ids): + n=self.num_layers;x=self.tok_emb(input_ids) + if self.bigram is not _A:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={} + for i in range(self.num_encoder_layers):ve=self._get_ve(i,input_ids,ve_cache);x=self.blocks[i](x,x0,self.qo_bank[i],self.kv_bank[i],self.kv_bank[n+i],self.qo_bank[n+i],self.mlp_up_bank[i],self.mlp_down_bank[i],v_embed=ve);skips.append(x) + for i in range(self.num_decoder_layers): + bi=self.num_encoder_layers+i + if skips:g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skips.pop();x=torch.lerp(scaled_skip,x,g) + ve=self._get_ve(bi,input_ids,ve_cache);x=self.blocks[bi](x,x0,self.qo_bank[bi],self.kv_bank[bi],self.kv_bank[n+bi],self.qo_bank[n+bi],self.mlp_up_bank[bi],self.mlp_down_bank[bi],v_embed=ve) + x=self.final_norm(x);x_flat=x.reshape(-1,x.size(-1));targets=target_ids.reshape(-1) + if self.tie_embeddings:logits_proj=F.linear(x_flat,self.tok_emb.weight) + else: + if self.lm_head is _A:raise RuntimeError('lm_head is required when tie_embeddings=False') + logits_proj=self.lm_head(x_flat) + logits=self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap);return F.cross_entropy(logits.float(),targets,reduction='mean') + def forward_hidden(self,input_ids): + n=self.num_layers;x=self.tok_emb(input_ids) + if self.bigram is not _A:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={} + for i in range(self.num_encoder_layers):ve=self._get_ve(i,input_ids,ve_cache);x=self.blocks[i](x,x0,self.qo_bank[i],self.kv_bank[i],self.kv_bank[n+i],self.qo_bank[n+i],self.mlp_up_bank[i],self.mlp_down_bank[i],v_embed=ve);skips.append(x) + for i in range(self.num_decoder_layers): + bi=self.num_encoder_layers+i + if skips:g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skips.pop();x=torch.lerp(scaled_skip,x,g) + ve=self._get_ve(bi,input_ids,ve_cache);x=self.blocks[bi](x,x0,self.qo_bank[bi],self.kv_bank[bi],self.kv_bank[n+bi],self.qo_bank[n+bi],self.mlp_up_bank[bi],self.mlp_down_bank[bi],v_embed=ve) + return self.final_norm(x) + def compute_logits(self,hidden): + if self.tie_embeddings:logits_proj=F.linear(hidden,self.tok_emb.weight) + else:logits_proj=self.lm_head(hidden) + return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward_logits(self,input_ids):return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;total_tokens=val_tokens.numel()-1;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=1];total_windows=len(window_starts);my_s=total_windows*rank//world_size;my_e=total_windows*(rank+1)//world_size;my_windows=window_starts[my_s:my_e];loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);base_model.eval();use_slot=getattr(args,'slot_enabled',_C);compiled_logits=torch.compile(base_model.forward_logits,dynamic=_C,fullgraph=_B);compiled_hidden=torch.compile(base_model.forward_hidden,dynamic=_C,fullgraph=_B)if use_slot else _A + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk[:-1];y_batch[i,:wlen]=chunk[1:] + if use_slot: + with torch.no_grad(),torch.autocast(device_type=_F,dtype=torch.bfloat16):H=compiled_hidden(x_batch) + H=H.detach().float();delta=torch.zeros(1,1,H.shape[-1],device=device,dtype=H.dtype,requires_grad=_B);slot_opt=torch.optim.AdamW([delta],lr=args.slot_lr,weight_decay=1e-08,eps=1e-05) + for _ in range(args.slot_steps):slot_opt.zero_grad();adapted=base_model.compute_logits((H+delta).to(torch.bfloat16)).float();slot_loss=F.cross_entropy(adapted[:,:-1].reshape(-1,adapted.size(-1)),y_batch[:,:seq_len-1].reshape(-1),reduction='mean');slot_loss.backward();slot_opt.step() + with torch.no_grad():logits=base_model.compute_logits((H+delta.detach()).to(torch.bfloat16)) + else: + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=compiled_logits(x_batch) + with torch.no_grad(): + nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + val_loss=(loss_sum/token_count).item();bits_per_token=val_loss/math.log(2.);tokens_per_byte=token_count.item()/byte_count.item();base_model.train();return val_loss,bits_per_token*tokens_per_byte +def eval_val_sliding_ttt(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,log0=print): + seq_len=args.train_seq_len;total_tokens=val_tokens.numel()-1;ttt_chunk=args.ttt_chunk_tokens;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=stride or ws==0];num_chunks=(total_tokens+ttt_chunk-1)//ttt_chunk;chunk_windows=[[]for _ in range(num_chunks)] + for ws in window_starts:end=min(ws+seq_len,total_tokens);wlen=end-ws;s=0 if ws==0 else max(wlen-stride,0);scored_start=ws+s;ci=min(scored_start//ttt_chunk,num_chunks-1);chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} total_windows={len(window_starts)} stride={stride} ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}");loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);frozen_block_ids=set(range(min(args.ttt_freeze_blocks,len(base_model.blocks))));ttt_params=[] + for(name,p)in base_model.named_parameters(): + freeze=_C + for bi in frozen_block_ids: + if f"blocks.{bi}."in name:freeze=_B;break + if freeze:p.requires_grad_(_C) + else:p.requires_grad_(_B);ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel()for p in ttt_params)} frozen={sum(p.numel()for p in base_model.parameters()if not p.requires_grad)}");optimizer=torch.optim.SGD(ttt_params,lr=args.ttt_lr,momentum=args.ttt_momentum);t0=time.perf_counter() + for ci in range(num_chunks): + windows=chunk_windows[ci] + if not windows:continue + chunk_start=ci*ttt_chunk;chunk_end=min((ci+1)*ttt_chunk,total_tokens);my_s=len(windows)*rank//world_size;my_e=len(windows)*(rank+1)//world_size;my_windows=windows[my_s:my_e];base_model.eval() + with torch.inference_mode(): + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk_tok=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk_tok[:-1];y_batch[i,:wlen]=chunk_tok[1:] + with torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=base_model.forward_logits(x_batch) + nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt,prev=y_batch[i,s:wlen],x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + is_last_chunk=ci==num_chunks-1 + if not is_last_chunk and args.ttt_epochs>0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=args.ttt_lr*.5*(_D+math.cos(math.pi*ci/max(num_chunks-1,1))) + for pg in optimizer.param_groups:pg[_H]=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0,my_chunk_seqs,args.ttt_batch_seqs): + be=min(bs+args.ttt_batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_tokens.numel():continue + local=val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=_B) + with torch.autocast(device_type=_F,dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,args.ttt_grad_clip);optimizer.step() + if rank==0 and(ci%10==0 or ci==num_chunks-1):elapsed=time.perf_counter()-t0;rl=loss_sum.item()/max(token_count.item(),1);rbpb=rl/math.log(2.)*(token_count.item()/max(byte_count.item(),1))if token_count.item()>0 else _E;log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + val_loss=(loss_sum/token_count).item();val_bpb=val_loss/math.log(2.)*(token_count.item()/byte_count.item()) + for p in base_model.parameters():p.requires_grad_(_B) + base_model.eval();log0(f"ttt_sliding:done val_loss={val_loss:.6f}{ val_bpb=:.6f} elapsed={time.perf_counter()-t0:.1f}s");return val_loss,val_bpb +def _classify_param(name): + A='.mlp.' + if'tok_emb'in name or'lm_head'in name:return'embed' + if A in name:return'mlp' + if'.attn.'in name or'.proj.'in name and A not in name:return'attn' + return'other' +def quantize_int6_per_row(t,clip_range=31): + t32=t.float() + if t32.ndim==2: + best_q,best_s,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(t32.abs(),pct,dim=1) + else:row_clip=t32.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);q=torch.clamp(torch.round(t32/s.float()[:,_A]),-clip_range,clip_range).to(torch.int8);recon=q.float()*s.float()[:,_A];err=(t32-recon).pow(2).mean().item() + if err0 else _D,dtype=torch.float16);q=torch.clamp(torch.round(t32/scale.float()),-clip_range,clip_range).to(torch.int8);return q,scale +def _unbank_state_dict(sd,num_layers): + out={};n=num_layers + for(name,tensor)in sd.items(): + if name==_R: + for i in range(n):out[f"blocks.{i}.attn.c_q.weight"]=tensor[i];out[f"blocks.{i}.attn.proj.weight"]=tensor[n+i] + elif name==_S: + for i in range(n):out[f"blocks.{i}.attn.c_k.weight"]=tensor[i];out[f"blocks.{i}.attn.c_v.weight"]=tensor[n+i] + elif name==_T: + for i in range(n):out[f"blocks.{i}.mlp.fc.weight"]=tensor[i] + elif name==_U: + for i in range(n):out[f"blocks.{i}.mlp.proj.weight"]=tensor[i] + else:out[name]=tensor + return out +def _rebank_state_dict(sd,num_layers,template_sd): + out={};n=num_layers;qo_slices=[_A]*(2*n);kv_slices=[_A]*(2*n);up_slices=[_A]*n;down_slices=[_A]*n;consumed=set() + for i in range(n): + qk=f"blocks.{i}.attn.c_q.weight" + if qk in sd:qo_slices[i]=sd[qk];consumed.add(qk) + ok=f"blocks.{i}.attn.proj.weight" + if ok in sd:qo_slices[n+i]=sd[ok];consumed.add(ok) + kk=f"blocks.{i}.attn.c_k.weight" + if kk in sd:kv_slices[i]=sd[kk];consumed.add(kk) + vk=f"blocks.{i}.attn.c_v.weight" + if vk in sd:kv_slices[n+i]=sd[vk];consumed.add(vk) + fk=f"blocks.{i}.mlp.fc.weight" + if fk in sd:up_slices[i]=sd[fk];consumed.add(fk) + dk=f"blocks.{i}.mlp.proj.weight" + if dk in sd:down_slices[i]=sd[dk];consumed.add(dk) + out[_R]=torch.stack(qo_slices).to(dtype=template_sd[_R].dtype);out[_S]=torch.stack(kv_slices).to(dtype=template_sd[_S].dtype);out[_T]=torch.stack(up_slices).to(dtype=template_sd[_T].dtype);out[_U]=torch.stack(down_slices).to(dtype=template_sd[_U].dtype) + for(name,tensor)in sd.items(): + if name not in consumed:out[name]=tensor + return out +def mixed_quantize_int6(state_dict,int6_cats,clip_range=31,hessians=_A): + A='type';num_layers_total=max((int(k.split('.')[1])for k in state_dict if k.startswith('blocks.')),default=0)+1;late_k_layers=set(range(num_layers_total-2,num_layers_total));result={};meta={};gptq_count,naive_count=0,0 + for(name,tensor)in state_dict.items(): + t=tensor.detach().cpu().contiguous();cat=_classify_param(name) + if not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]=_L;continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):result[name]=t.float();meta[name]=_i;continue + if cat in int6_cats and t.ndim>=1: + H=hessians.get(name)if hessians else _A + if H is not _A and t.ndim==2:q,s=gptq_quantize_weight(t,H.cpu(),clip_range=clip_range);gptq_count+=1 + else:q,s=quantize_int6_per_row(t,clip_range=clip_range);naive_count+=1 + result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int6'} + else:q,s=quantize_float_tensor(t);result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int8'} + if hessians:print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers",flush=_B) + return result,meta +def dequantize_mixed_int6(result,meta,template_sd): + out={} + for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is _A:continue + orig_dtype=orig.dtype + if info in(_L,_i,'passthrough_fp16'): + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + q,s=result[name+'.q'],result[name+_V] + if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +def gptq_quantize_weight(W,H,clip_range=31,block_size=128,percdamp=.01): + W_orig=W.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=percdamp*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=_B);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm] + try:Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=_B) + except torch.linalg.LinAlgError:return quantize_int6_per_row(W_orig,clip_range) + best_q,best_scale,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(W_orig.abs(),pct,dim=1) + else:row_clip=W_orig.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20 else args.train_seq_len;val_seq_len=max(args.train_seq_len,effective_eval_seq_len);val_tokens=load_validation_tokens(args.val_files,val_seq_len);base_bytes_lut,has_leading_space_lut,is_boundary_token_lut=build_sentencepiece_luts(sp,args.vocab_size,device);log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}");log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}");log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel()-1}");base_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,neg_slope=args.negative_slope).to(device).bfloat16();base_model.qo_bank.data=base_model.qo_bank.data.float();base_model.kv_bank.data=base_model.kv_bank.data.float();base_model.mlp_up_bank.data=base_model.mlp_up_bank.data.float();base_model.mlp_down_bank.data=base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module,CastedLinear):module.float() + restore_low_dim_params_to_fp32(base_model);compiled_model=torch.compile(base_model,dynamic=_C,fullgraph=_B);model=compiled_model;matrix_params=[base_model.qo_bank,base_model.kv_bank,base_model.mlp_up_bank,base_model.mlp_down_bank];block_named_params=list(base_model.blocks.named_parameters());scalar_params=[p for(name,p)in block_named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights);scalar_params.append(base_model.skip_gates) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not _A:scalar_params.append(base_model.bigram.scale) + token_lr=args.tied_embed_lr if args.tie_embeddings else args.embed_lr;tok_params=[{_G:[base_model.tok_emb.weight],_H:token_lr,A:token_lr}] + if base_model.bigram is not _A: + tok_params.append({_G:[base_model.bigram.embed.weight],_H:token_lr,A:token_lr}) + if base_model.bigram.proj is not _A:scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not _A: + tok_params.append({_G:[base_model.ve_shared.embed.weight],_H:token_lr,A:token_lr}) + if base_model.ve_shared.proj is not _A:scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales:scalar_params.append(s) + optimizer_tok=torch.optim.AdamW(tok_params,betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);optimizer_muon=Muon(matrix_params,lr=args.matrix_lr,momentum=args.muon_momentum,backend_steps=args.muon_backend_steps,weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups:group[A]=args.matrix_lr + optimizer_scalar=torch.optim.AdamW([{_G:scalar_params,_H:args.scalar_lr,A:args.scalar_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);replicated_params=list(optimizer_tok.param_groups[0][_G]) + for pg in optimizer_tok.param_groups[1:]:replicated_params.extend(pg[_G]) + replicated_params.extend(scalar_params);optimizer_head=_A + if base_model.lm_head is not _A:optimizer_head=torch.optim.Adam([{_G:[base_model.lm_head.weight],_H:args.head_lr,A:args.head_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,fused=_B);replicated_params.append(base_model.lm_head.weight) + optimizers=[optimizer_tok,optimizer_muon,optimizer_scalar] + if optimizer_head is not _A:optimizers.append(optimizer_head) + log0(f"model_params:{sum(p.numel()for p in base_model.parameters())}");xsa_layers=[i for(i,b)in enumerate(base_model.blocks)if b.attn.use_xsa];log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}");log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}");log0('sdp_backends:cudnn=False flash=True mem_efficient=False math=False');log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}");log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not _A else _E} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}");log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}");log0(f"seed:{args.seed}");train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + def zero_grad_all(): + for opt in optimizers:opt.zero_grad(set_to_none=_B) + max_wallclock_ms=1e3*args.max_wallclock_seconds if args.max_wallclock_seconds>0 else _A + if args.use_gptq and max_wallclock_ms is not _A:max_wallclock_ms-=args.gptq_reserve_ms;log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") + def lr_mul(step,elapsed_ms): + if args.warmdown_iters<=0:return _D + if max_wallclock_ms is _A:warmdown_start=max(args.iterations-args.warmdown_iters,0);return max((args.iterations-step)/max(args.warmdown_iters,1),_E)if warmdown_start<=step0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x,y=train_loader.next_batch(args.train_batch_tokens,args.train_seq_len,grad_accum_steps) + with torch.autocast(device_type=_F,dtype=torch.bfloat16,enabled=_B):warmup_loss=model(x,y) + (warmup_loss*grad_scale).backward() + if distributed: + for p in base_model.parameters(): + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + for opt in optimizers:opt.step() + zero_grad_all() + if args.warmup_steps<=20 or(warmup_step+1)%10==0 or warmup_step+1==args.warmup_steps:log0(f"warmup_step:{warmup_step+1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state,strict=_B) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=_B):opt.load_state_dict(state) + zero_grad_all();train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + swa_state=_A;swa_count=0;ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=.997;training_time_ms=_E;stop_after_step=_A;torch.cuda.synchronize();t0=time.perf_counter();step=0 + while _B: + last_step=step==args.iterations or stop_after_step is not _A and step>=stop_after_step;should_validate=last_step or args.val_loss_every>0 and step%args.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not _A and step0 else _D;muon_momentum=(1-frac)*args.muon_momentum_warmup_start+frac*args.muon_momentum + for group in optimizer_muon.param_groups:group[_a]=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group[_H]=group[A]*scale + if args.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),args.grad_clip_norm) + if args.matrix_lr_early!=args.matrix_lr or args.matrix_lr_late!=args.matrix_lr: + s=args.bank_split;n=args.num_layers;es=args.matrix_lr_early/args.matrix_lr;ls=args.matrix_lr_late/args.matrix_lr + with torch.no_grad(): + for bank in[base_model.qo_bank,base_model.kv_bank]: + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:n].mul_(ls);bank.grad[n:n+s].mul_(es);bank.grad[n+s:].mul_(ls) + for bank in[base_model.mlp_up_bank,base_model.mlp_down_bank]: + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:].mul_(ls) + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + optimizer_tok.step();optimizer_scalar.step() + if optimizer_head is not _A:optimizer_head.step() + optimizer_muon.step();zero_grad_all() + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=_D-ema_decay) + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0) + if args.late_qat_threshold>0 and scale=2000: + if not CastedLinear._qat_enabled:CastedLinear._qat_enabled=_B;CastedLinear._qat_start_step=step;log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + qat_progress=min((step-CastedLinear._qat_start_step)/max(500,1),_D);CastedLinear._qat_alpha=_D+15.*qat_progress + if args.swa_enabled and scale<.2 and step%args.swa_every==0: + if swa_state is _A:swa_state={name:t.detach().cpu().clone()for(name,t)in base_model.state_dict().items()};swa_count=1;log0(f"swa:start step:{step}") + else: + for(name,t)in base_model.state_dict().items():swa_state[name]+=t.detach().cpu() + swa_count+=1 + should_log_train=args.train_log_every>0 and(step<=10 or step%args.train_log_every==0 or stop_after_step is not _A) + if should_log_train:log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/step:.2f}ms") + reached_cap=max_wallclock_ms is not _A and approx_training_time_ms>=max_wallclock_ms + if distributed and max_wallclock_ms is not _A:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is _A and reached_cap:stop_after_step=step + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log0('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=_B);torch.cuda.synchronize();t_diag=time.perf_counter();diag_val_loss,diag_val_bpb=eval_val(args,compiled_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);torch.cuda.synchronize();log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_diag):.0f}ms");export_sd=base_model.state_dict() + if master_process:torch.save(export_sd,E);model_bytes=os.path.getsize(E);code_bytes=len(code.encode(_I));log0(f"Serialized model: {model_bytes} bytes");log0(f"Code size: {code_bytes} bytes") + sd_cpu={k:v.detach().cpu()for(k,v)in export_sd.items()};unbanked_sd=_unbank_state_dict(sd_cpu,args.num_layers);gptq_hessians=_A + if args.use_gptq:t_gptq=time.perf_counter();log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...");calib_loader=DistributedTokenLoader(args.train_files,rank,world_size,device);gptq_hessians=gptq_collect_hessians(base_model,calib_loader,device,num_batches=args.gptq_calib_samples,batch_tokens=args.train_batch_tokens,seq_len=args.train_seq_len,grad_accum_steps=grad_accum_steps);del calib_loader;gptq_elapsed=time.perf_counter()-t_gptq;log0(f"gptq:calibrated {len(gptq_hessians)} layers in {gptq_elapsed:.1f}s");torch.cuda.empty_cache() + quant_result,quant_meta=mixed_quantize_int6(unbanked_sd,{'mlp','attn'},clip_range=args.quant_clip_range,hessians=gptq_hessians);quant_buf=io.BytesIO();torch.save({'w':quant_result,'m':quant_meta},quant_buf);quant_raw=quant_buf.getvalue();quant_blob=brotli.compress(_byte_shuffle(quant_raw),quality=11) + if master_process: + with open(F,'wb')as f:f.write(quant_blob) + quant_file_bytes=len(quant_blob);code_bytes=len(code.encode(_I));log0(f"Serialized model int6+brotli: {quant_file_bytes} bytes");log0(f"Total submission size int6+brotli: {quant_file_bytes+code_bytes} bytes") + if distributed:dist.barrier() + with open(F,'rb')as f:quant_blob_disk=f.read() + quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],unbanked_sd);deq_state=_rebank_state_dict(deq_unbanked,args.num_layers,sd_cpu);eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,neg_slope=args.negative_slope).to(device).bfloat16();eval_model.qo_bank.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model);eval_model.load_state_dict(deq_state,strict=_B);compiled_eval=torch.compile(eval_model,dynamic=_C,fullgraph=_B);torch.cuda.synchronize();t_qeval=time.perf_counter();q_val_loss,q_val_bpb=eval_val(args,compiled_eval,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);torch.cuda.synchronize();log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_qeval):.0f}ms");log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_stride