From 938d6ae7ee145d4a9fb49aafc6c51bdede285faf Mon Sep 17 00:00:00 2001
From: Aryan Bhosale
Date: Fri, 3 Apr 2026 16:55:46 +0530
Subject: [PATCH 1/4] =?UTF-8?q?Record:=20SP4096=20+=20Depth=20Recurrence?=
 =?UTF-8?q?=20+=20MuonEq-R=20+=20Full=20GPTQ=20=E2=80=94=20val=5Fbpb=201.0?=
 =?UTF-8?q?940=20(3-seed=20mean)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

4096-vocab + MLP 4x + WD 0.090 + depth recurrence (layers 4,5) + MuonEq-R +
full GPTQ int6 + brotli + selective pruning. 3-seed mean: 1.0940 BPB,
beating merged SOTA (PR #1019, 1.1147 BPB) by 0.0208 BPB.
---
 .../README.md                                 |   98 +
 .../submission.json                           |   38 +
 .../train_gpt.py                              | 1911 +++++++++++++++
 .../train_seed314.log                         |  132 ++
 .../train_seed42.log                          |  132 ++
 .../train_seed999.log                         | 2094 +++++++++++++++++
 6 files changed, 4405 insertions(+)
 create mode 100644 records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md
 create mode 100644 records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json
 create mode 100644 records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py
 create mode 100644 records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log
 create mode 100644 records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log
 create mode 100644 records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log

diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md
new file mode 100644
index 0000000000..18f6df3ae4
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md
@@ -0,0 +1,98 @@
+# Record: SP4096 + Depth Recurrence + MuonEq-R + Full GPTQ — val_bpb 1.0940 (3-seed mean)
+
+**val_bpb = 1.0940** (3-seed mean, std 0.0005) | **~15.96 MB** | 8xH100 SXM
+
+## 3-Seed Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128)
+
+| Seed | Steps | Pre-quant BPB | **Sliding BPB** | Artifact (bytes) |
+|------|-------|---------------|-----------------|------------------|
+| 42 | 5,415 | 1.0997 | **1.0942** | 15,960,147 |
+| 314 | 5,415 | 1.0995 | **1.0934** | 15,963,424 |
+| 999 | 5,420 | 1.0996 | **1.0942** | 15,958,655 |
+| **Mean** | | | **1.0940** | |
+
+Merged SOTA (PR #1019): **1.1147 BPB**. This run: **1.0940 BPB**. Delta: **-0.0208 BPB** (Welch t=-61.9). In nats, 0.0208 BPB x ln 2 ~= 0.0144, so the run clears the 0.005-nat threshold by roughly 3x.
+
+## Changes from Merged SOTA (PR #1019)
+
+This submission combines the PR #1218 4096-vocab architecture with depth recurrence, MuonEq-R, and higher weight decay for more compressible weights.
+
+### 1. 4096-Vocab + MLP 4x + WD 0.090
+
+Switched from the sp1024 to the sp4096 tokenizer (4096 BPE tokens vs 1024) and widened the MLP (4x expansion vs 3x). Higher weight decay (0.090 vs 0.04) produces smaller weights that compress ~5% better with brotli, which allows all 66 quantized layers to use int6 precision.
+
+Source: PR #1218 by @clarkkev (4096-vocab + MLP 4x + WD 0.085), PR #1285 by @dexhunter (WD 0.090 + all-int6).
+
+### 2. Depth Recurrence (layers 4,5 repeated)
+
+Layers 4 and 5 (the U-Net hinge point) execute twice during the forward pass using the same physical parameter banks: a virtual 13-layer network from an 11-layer parameter budget, with zero extra parameters. Recurrence activates at step 3000; see the sketch below.
+
+Source: PR #1204 by @msisovic (concept), PR #1260 by @dexhunter (implementation).
+
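+For reference, a minimal standalone sketch of the virtual-layer schedule (the helper name here is illustrative; the in-repo implementation is `GPT._get_virtual_layers` in `train_gpt.py` below):
+
+```python
+def virtual_layer_order(num_layers: int, recur_layers: list[int], active: bool) -> list[int]:
+    """Physical block indices in forward-pass order; recur_layers run twice."""
+    if not active or not recur_layers:
+        return list(range(num_layers))
+    order: list[int] = []
+    for i in range(num_layers):
+        order.append(i)
+        if i == recur_layers[-1]:
+            order.extend(recur_layers)  # second pass through the hinge layers
+    return order
+
+# 11 physical layers -> 13 virtual layers, no new parameters:
+assert virtual_layer_order(11, [4, 5], True) == [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
+```
+
+### 3. 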
MuonEq-R (Row-Normalized Muon) + +Row-normalizes gradient matrices before Newton-Schulz orthogonalization for better-conditioned optimization. Zero cost. + +Source: arXiv:2603.28254, PR #1260 by @dexhunter. + +### 4. Full GPTQ int6 + Brotli + Selective Pruning + +Full Hessian GPTQ with training-data calibration. Brotli-11 compression with byte-shuffle. Selective +-1 pruning by reconstruction error to fit under 16MB. + +Source: PR #1019 by @abaybektursun (GPTQ), PR #1218 by @clarkkev (brotli + byte-shuffle). + +## Architecture + +| Component | Setting | +|-----------|---------| +| Vocab | 4096 (sp4096 BPE) | +| Layers | 11 physical (13 virtual with recurrence) | +| Dimensions | 512d, 8H / 4KV (GQA) | +| MLP | 4x (2048), LeakyReLU(0.5)^2 | +| XSA | All 11 layers | +| QK Gain | 4.0 | +| RoPE | Partial (16/64 dims) | +| LN Scale | 1/sqrt(layer+1) | +| VE128 | Layers 9-10 | +| Skip gates | Sigmoid-gated U-Net | +| Weight avg | EMA(0.997) | +| Optimizer | MuonEq-R (lr=0.02, WD=0.090) | +| Quantization | Full GPTQ int6 + brotli-11 + byte-shuffle | +| Warmdown | 66.7% of steps | + +## Training + +- MuonEq-R: lr=0.02, momentum 0.92->0.99/1500 steps, WD=0.090 +- Adam for embeddings (lr=0.03) and scalars (lr=0.02) +- Batch 786,432 tokens, seq_len 2048 +- Depth recurrence activates at step 3000 +- ~5415 steps in 590s (~109ms/step with recurrence) + +## Compliance + +- No TTT, no SLOT, no n-gram cache, no eval-time adaptation +- GPTQ calibration uses training data within the training time budget +- All seeds within 600s training, <16MB artifact +- Fully legal under all four conditions (Issue #1017) + +## Reproduction + +```bash +# Download sp4096 data +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp4096 --skip-manifest + +# Train +SEED=42 RECUR_LAYERS=4,5 RECUR_START_STEP=3000 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Credits + +- **4096-Vocab + MLP 4x + WD 0.085 + Brotli**: PR #1218 by @clarkkev +- **WD 0.090 + All-Int6**: PR #1285 by @dexhunter +- **Depth Recurrence concept**: PR #1204 by @msisovic +- **MuonEq-R + Depth Recurrence implementation**: PR #1260 by @dexhunter +- **Full GPTQ + XSA-all**: PR #1019 by @abaybektursun +- **Base architecture**: PR #1287 by @dentity007 +- **LeakyReLU^2**: PR #493 by @parinzee +- **LN Scale + Partial RoPE**: PR #315 by @jfprincz diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json new file mode 100644 index 0000000000..caa4ae0046 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json @@ -0,0 +1,38 @@ +{ + "author": "aryanbhosale", + "github_id": "aryanbhosale", + "name": "SP4096 + Depth Recurrence + MuonEq-R + Full GPTQ", + "blurb": "4096-vocab (sp4096) + MLP 4x + WD 0.090 + depth recurrence (layers 4,5) + MuonEq-R + full GPTQ int6 + brotli + selective pruning. 
3-seed mean: 1.09398 BPB, beating merged SOTA (PR #1019, 1.11474 BPB) by 0.02076 BPB (Welch t=-61.9).", + "date": "2026-04-03", + "track": "10min_16mb", + "val_bpb": 1.09397576, + "val_bpb_std": 0.00046067, + "seeds": [42, 314, 999], + "seed_results": { + "42": { + "val_loss": 2.51787602, + "val_bpb": 1.09423991, + "artifact_bytes": 15960147, + "steps": 5415 + }, + "314": { + "val_loss": 2.51604422, + "val_bpb": 1.09344383, + "artifact_bytes": 15963424, + "steps": 5415 + }, + "999": { + "val_loss": 2.51788435, + "val_bpb": 1.09424353, + "artifact_bytes": 15958655, + "steps": 5420 + } + }, + "comparison_baseline_pr": 1019, + "delta_vs_pr1019_bpb": -0.02075933, + "t_statistic": -61.8978, + "artifact_bytes_max": 15963424, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "technique_summary": "SP4096 + MLP 4x + WD 0.090 + Depth Recurrence (layers 4,5) + MuonEq-R + Full GPTQ int6 + Brotli + Selective Pruning" +} diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py new file mode 100644 index 0000000000..8bc4ed613a --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py @@ -0,0 +1,1911 @@ +import copy +import glob +import io +import lzma +import math +import os +from pathlib import Path +import random +import subprocess +import sys +import time +import uuid + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn + +from flash_attn_interface import flash_attn_func as flash_attn_3_func + +try: + import brotli + _HAS_BROTLI = True +except ImportError: + _HAS_BROTLI = False + +# ---------------------------------------- +# Hyperparameters +# ---------------------------------------- + +class Hyperparameters(): + # Experiment settings + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + + # Training length + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + + # Validation/Evals + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 4096)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', 
'1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + ve_enabled = bool(int(os.environ.get('VE_ENABLED', '1'))) + ve_dim = int(os.environ.get('VE_DIM', 128)) + ve_layers = os.environ.get('VE_LAYERS', '9,10') + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) + + # Optimizer (Modification 3: weight decay 0.090) + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.090)) + embed_wd = float(os.environ.get('EMBED_WD', 0.090)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.997)) + + # Depth Recurrence (Modification 2) + recur_layers = os.environ.get("RECUR_LAYERS", "4,5") + recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000)) + + # TTT (Modification 4) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + + # Compression + compressor = os.environ.get('COMPRESSOR', 'brotli') #(lzma or brotli) + gptq_enabled = bool(int(os.environ.get('GPTQ_ENABLED', '1'))) + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 10.0)) + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + + # Data paths + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + + # Experiment files + logfile = f"logs/{run_id}.txt" + model_path = "final_model.pt" + quantized_model_path = 
"final_model.int6.ptz" + +# ---------------------------------------- +# Global Logging Function +# ---------------------------------------- + +_logger_hparams = None + + +def set_logging_hparams(h: Hyperparameters) -> None: + global _logger_hparams + _logger_hparams = h + + +def log(msg, console: bool = True) -> None: + if _logger_hparams is None: + print(msg) + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + +# ---------------------------------------- +# Data Loading +# ---------------------------------------- + +class ValidationData: + def __init__(self, h: Hyperparameters, device: torch.device): + if not h.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {h.tokenizer_path}") + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = ( + build_sentencepiece_luts(self.sp, h.vocab_size, device)) + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + # The BPB calculation assumes "▁" is its own token so that leading-space bytes + # are counted correctly. See https://github.com/openai/parameter-golf/issues/897 + assert sp.piece_to_id("\u2581") != sp.unk_id(), \ + "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" int: + key = str(file) + cached = _SHARD_NTOKENS_CACHE.get(key) + if cached is not None: + return cached + header = np.fromfile(file, dtype=" np.memmap: + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype=" int: + if n <= 1: + return 1 + while True: + s = int(self._rng.integers(1, n)) + if math.gcd(s, n) == 1: + return s + + def _reset_cursor(self, si: int, seq_len: int) -> None: + nt = int(self._num_tokens[si]) + max_phase = min(seq_len - 1, max(0, nt - seq_len - 1)) + phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0 + bc = (nt - 1 - phase) // seq_len + self._cursor_phase[si] = phase + self._cursor_block_count[si] = bc + self._cursor_next[si] = 0 + self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0 + self._cursor_stride[si] = self._pick_coprime_stride(bc) + self._cursor_init[si] = True + + def _ensure_cursor(self, si: int, seq_len: int) -> None: + if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]: + self._reset_cursor(si, seq_len) + + def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None: + rem = count + while rem > 0: + self._ensure_cursor(si, seq_len) + bc = int(self._cursor_block_count[si]) + ni = int(self._cursor_next[si]) + take = min(rem, bc - ni) + phase = int(self._cursor_phase[si]) + start = int(self._cursor_start[si]) + stride = int(self._cursor_stride[si]) + for j in range(take): + bi = (start + (ni + j) * stride) % bc + out.append((si, phase + bi * seq_len)) + self._cursor_next[si] = ni + take + rem -= take + + def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + num_seqs = local_tokens // seq_len + global_num_seqs = num_seqs * self.world_size + self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) + bbc = (self._num_tokens - 1) // seq_len + eligible = bbc > 0 + self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) + self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) + + def _sample_global_windows(self) -> list[tuple[int, int]]: + assert self._cfg is not None and self._eligible_shards is not None + _, seq_len, _, gns = self._cfg + ec = int(self._eligible_shards.size) + progress = min(self._batches_built / 1800.0, 1.0) + remaining = np.empty(ec, dtype=np.float64) + for i, si in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]: + r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) + remaining[i] = float(max(r, 1)) + else: + remaining[i] = float(self._base_block_counts[i]) + alpha = 0.90 - 0.40 * progress + weights = np.power(remaining, alpha) + ws = float(weights.sum()) + if not np.isfinite(ws) or ws <= 0.0: + weights = np.ones(ec, dtype=np.float64) + ws = float(weights.sum()) + probs = weights / ws + low = min(max(8, self.world_size), ec, gns) + high = min(max(32, self.world_size * 8), ec, gns) + mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) + cp = self._rng.choice(ec, size=mix, replace=False, p=probs) + cs = 
self._eligible_shards[cp] + cpr = probs[cp].copy() + cpr /= cpr.sum() + counts = np.ones(mix, dtype=np.int64) + extra = gns - mix + if extra > 0: + counts += self._rng.multinomial(extra, cpr).astype(np.int64) + perm = self._rng.permutation(mix) + cs, counts = cs[perm], counts[perm] + buckets: list[list[tuple[int, int]]] = [] + for si, cnt in zip(cs.tolist(), counts.tolist()): + b: list[tuple[int, int]] = [] + self._take_from_shard(int(si), seq_len, int(cnt), b) + if b: + if len(b) > 1: + bp = self._rng.permutation(len(b)) + b = [b[int(k)] for k in bp.tolist()] + buckets.append(b) + windows: list[tuple[int, int]] = [] + active = [i for i, bk in enumerate(buckets) if bk] + while active: + order = self._rng.permutation(len(active)) + new_active: list[int] = [] + for oi in order.tolist(): + bi = active[oi] + if buckets[bi]: + windows.append(buckets[bi].pop()) + if buckets[bi]: + new_active.append(bi) + active = new_active + return windows + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self._cfg is None: + self._init_pipeline(global_tokens, seq_len, grad_accum_steps) + _, _, num_seqs, _ = self._cfg + gw = self._sample_global_windows() + local_w = gw[self.rank::self.world_size] + x = torch.empty((num_seqs, seq_len), dtype=torch.int64) + y = torch.empty((num_seqs, seq_len), dtype=torch.int64) + for slot, (si, pos) in enumerate(local_w): + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) + x[slot] = window[:-1] + y[slot] = window[1:] + self._batches_built += 1 + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ---------------------------------------- +# Model Architecture +# ---------------------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = 
seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, train_seq_len: int): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, 
hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, train_seq_len: int, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h: Hyperparameters): + super().__init__() + self._ve_target_dim = h.num_kv_heads * (h.model_dim // h.num_heads) + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.ve_layer_indices = [int(x) for x in h.ve_layers.split(",") if x.strip()] if h.ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(h.vocab_size, h.ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RMSNorm() + 
self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + + # Modification 2: Depth Recurrence + self.recur_layers = [int(x) for x in h.recur_layers.split(",") if x.strip()] + self._recurrence_active = False + + self._init_weights() + + def set_recurrence_active(self, active: bool) -> None: + self._recurrence_active = active + + def _get_virtual_layers(self) -> list[int]: + """Return virtual->physical block mapping. + When recurrence is active, the recur_layers are repeated once, + e.g. with num_layers=11 and recur_layers=[4,5]: + [0,1,2,3, 4,5, 4,5, 6,7,8,9,10] + When inactive: [0,1,2,...,num_layers-1] + """ + n = len(self.blocks) + if not self._recurrence_active or not self.recur_layers: + return list(range(n)) + virtual = [] + inserted = False + for i in range(n): + virtual.append(i) + if not inserted and i == self.recur_layers[-1]: + # repeat the recur_layers + for rl in self.recur_layers: + virtual.append(rl) + inserted = True + return virtual + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + + virtual_layers = self._get_virtual_layers() + num_virtual = len(virtual_layers) + num_enc = num_virtual // 2 + num_dec = num_virtual - num_enc + + skips: list[Tensor] = [] + ve_cache: dict = {} + + # Encoder phase + for vi in range(num_enc): + phys_idx = virtual_layers[vi] + ve = self._get_ve(phys_idx, input_ids, ve_cache) + x = self.blocks[phys_idx](x, x0, v_embed=ve) + skips.append(x) + + # Decoder phase with U-Net skip connections + for vi in range(num_dec): + phys_idx = virtual_layers[num_enc + vi] + if skips and vi < self.num_skip_weights: + scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[vi].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + ve = self._get_ve(phys_idx, input_ids, ve_cache) + x = self.blocks[phys_idx](x, x0, v_embed=ve) + + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, 
target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") + + +def classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +# ---------------------------------------- +# Optimization +# ---------------------------------------- + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + # Modification 1: MuonEq-R row normalization before NS5 + update = g + row_norms = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-7) + update = update / row_norms.to(update.dtype) + g = zeropower_via_newtonschulz5(update, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +class Optimizers(): + def __init__(self, h: Hyperparameters, base_model: GPT): + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + 
scalar_params.append(base_model.skip_gates) + + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers: list[torch.optim.Optimizer] = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + +# ---------------------------------------- +# Quantization +# ---------------------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def restore_fp32_params(model: nn.Module) -> None: + """After .bfloat16(), restore CastedLinear weights and control params to FP32.""" + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pattern in name for 
pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + + +def collect_hessians( + model: nn.Module, + train_loader: DistributedTokenLoader, + h: Hyperparameters, + device: torch.device, + n_calibration_batches: int = 64, +) -> dict[str, Tensor]: + """Run calibration batches and collect H = X^T X for each CastedLinear layer.""" + hessians: dict[str, Tensor] = {} + hooks = [] + + def make_hook(name: str): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + for name, module in model.named_modules(): + if isinstance(module, CastedLinear) and module.weight.numel() > 65536: + cat = classify_param(name + ".weight") + if cat in ("mlp", "attn"): + hooks.append(module.register_forward_hook(make_hook(name + ".weight"))) + + model.eval() + with torch.no_grad(): + for _i in range(n_calibration_batches): + x, y = train_loader.next_batch( + h.train_batch_tokens, + h.train_seq_len, h.grad_accum_steps, + ) + model.forward_logits(x) + + for hk in hooks: + hk.remove() + + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + + return hessians + + +def gptq_quantize_weight( + w: Tensor, + H: Tensor, + clip_range: int = 31, + block_size: int = 128, +) -> tuple[Tensor, Tensor]: + """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023).""" + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + + # Zero out dead columns and add damping + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + + # Column reordering by descending Hessian diagonal (actorder) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + + # Upper Cholesky of the inverse + try: + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + except torch.linalg.LinAlgError: + return quantize_int6_per_row(W_orig, clip_range) + + # Search over scale candidates, running full GPTQ for each + best_q, best_scale, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(W_orig.abs(), pct, dim=1) + else: + row_clip = W_orig.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / 
clip_range).to(torch.float16) + sf = s.float() + + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + + recon = Q.float() * sf[:, None] + mse = (W_perm - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + + return best_q[:, invperm], best_scale + + +def gptq_mixed_quantize_int6( + state_dict: dict[str, Tensor], + int6_cats: set[str], + hessians: dict[str, Tensor], +) -> tuple[dict[str, Tensor], dict[str, object]]: + """Mixed quantization using full GPTQ for layers with Hessians, fallback to clip-search.""" + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + gptq_count = 0 + fallback_count = 0 + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = classify_param(name) + + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + + if cat in int6_cats and t.ndim == 2: + if name in hessians: + q, s = gptq_quantize_weight(t, hessians[name]) + gptq_count += 1 + meta[name] = {"type": "int6", "method": "gptq"} + else: + q, s = quantize_int6_per_row(t) + fallback_count += 1 + meta[name] = {"type": "int6", "method": "clip_search"} + result[name + ".q"] = q + result[name + ".scale"] = s + elif cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + + log(f"GPTQ quantization: {gptq_count} layers with full GPTQ, {fallback_count} fallback to clip-search") + return result, meta + + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + 
continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + """Transpose byte stream by stride position for better compression.""" + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data: bytes) -> bytes: + """Inverse of _byte_shuffle. Auto-detects BSHF magic header.""" + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: + if byte_shuffle: + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli as _brotli + return _brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli as _brotli + raw = _brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + if byte_shuffle: + raw = _byte_unshuffle(raw) + return raw + + +def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> int: + model_bytes = None + code_bytes = len(code.encode("utf-8")) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + if h.gptq_enabled: + log("GPTQ:collecting Hessians from calibration data...") + t0 = time.perf_counter() + calib_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, + torch.device("cuda", h.local_rank)) + hessians = collect_hessians( + base_model, calib_loader, h, + torch.device("cuda", h.local_rank), + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, hessians) + else: + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + + # Fast selective +-1 pruning to fit under target size + target_bytes = 16_000_000 + quant_buf_check = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf_check) + 
check_blob = _compress(quant_buf_check.getvalue(), h.compressor) + unpruned_sz = len(check_blob) + code_bytes + log(f"selective_prune: unpruned={unpruned_sz/1e6:.2f}MB target={target_bytes/1e6:.1f}MB") + if unpruned_sz > target_bytes: + excess = unpruned_sz - target_bytes + safety_margin = int(excess * 8) # prune 8x the excess for safety + ones_info = [] + for name, info in quant_meta.items(): + if not (isinstance(info, dict) and info.get("type") == "int6"): + continue + qk, sk = name + ".q", name + ".scale" + if qk not in quant_result or sk not in quant_result: + continue + q, s = quant_result[qk], quant_result[sk] + if s.ndim > 0: + ones_mask = (q.abs() == 1) + if ones_mask.any(): + row_idx = torch.arange(q.shape[0]).unsqueeze(1).expand_as(q)[ones_mask] + flat_idx = torch.arange(q.numel()).reshape(q.shape)[ones_mask] + errors = s.float()[row_idx].pow(2) + for fi, err in zip(flat_idx.tolist(), errors.tolist()): + ones_info.append((qk, fi, err)) + ones_info.sort(key=lambda x: x[2]) + n_prune = min(safety_margin, len(ones_info)) + log(f"selective_prune: pruning {n_prune}/{len(ones_info)} lowest-error ±1 values (excess={excess}B)") + for i in range(n_prune): + quant_result[ones_info[i][0]].view(-1)[ones_info[i][1]] = 0 + else: + log("selective_prune: already fits, no pruning needed") + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model int6+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size int6+{h.compressor}: {bytes_total} bytes") + + +def deserialize(h: Hyperparameters, device: torch.device) -> GPT: + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + + sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model.load_state_dict(deq_state, strict=True) + + return eval_model + +# ---------------------------------------- +# Evaluation +# ---------------------------------------- + +def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]: + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + model: nn.Module +) -> tuple[float, float]: + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " + f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * h.rank) // h.world_size + seq_end = (total_seqs * (h.rank + 1)) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, 
dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32 +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) + if ws + context_size < total_tokens] + + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * (h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += 
(val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +# ---------------------------------------- +# TTT (Test-Time Training) - Legal Score-First +# ---------------------------------------- + +def eval_val_ttt( + h: Hyperparameters, + base_model: nn.Module, + device: torch.device, + val_data: ValidationData, + log_fn=None, +) -> tuple[float, float]: + """Legal score-first TTT: score each chunk with sliding windows, + then train on it. Every token scored BEFORE any update that could use it.""" + seq_len = h.eval_seq_len + stride = h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + ttt_chunk = h.ttt_chunk_tokens + rank = h.rank + world_size = h.world_size + if log_fn is None: + log_fn = lambda msg: None + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={h.ttt_lr} ttt_epochs={h.ttt_epochs} " + f"freeze_blocks={h.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + frozen_block_ids = set(range(min(h.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=h.ttt_lr, momentum=h.ttt_momentum) + batch_seqs = h.ttt_batch_seqs + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (no_grad for TTT compat) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_data.val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and h.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = h.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(h.ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_data.val_tokens.numel(): + continue + local = val_data.val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, h.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = 
time.perf_counter() - t0
+            rl = loss_sum.item() / max(token_count.item(), 1)
+            rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0
+            log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s")
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = (loss_sum / token_count).item()
+    val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item())
+
+    for p in base_model.parameters():
+        p.requires_grad_(True)
+    base_model.eval()
+
+    log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} "
+           f"elapsed={time.perf_counter() - t0:.1f}s")
+    return val_loss, val_bpb
+
+
+# ----------------------------------------
+# Eval orchestration
+# ----------------------------------------
+
+def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]:
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+    val_loss, val_bpb = fn(*args, **kwargs)
+    torch.cuda.synchronize()
+    elapsed_ms = 1000.0 * (time.perf_counter() - t0)
+    log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms")
+    return val_loss, val_bpb
+
+
+def run_evals(
+    h: Hyperparameters,
+    device: torch.device,
+    val_data: ValidationData,
+    eval_model: torch.nn.Module
+):
+    compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True)
+    timed_eval("final_int6_roundtrip", eval_val, h, device, val_data, compiled_model)
+    if h.sliding_window_enabled:
+        timed_eval("final_int6_sliding_window", eval_val_sliding, h, device, val_data, eval_model)
+    if h.ttt_enabled:
+        timed_eval("final_int6_ttt", eval_val_ttt, h, eval_model, device, val_data, log_fn=log)
+
+# -----------------------------
+# Training
+# -----------------------------
+
+def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData) -> tuple[GPT, nn.Module]:
+    # Set up model
+    base_model = GPT(h).to(device).bfloat16()
+    restore_fp32_params(base_model)
+    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
+    if h.distributed:
+        model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False)
+    else:
+        model = compiled_model
+    log(f"model_params:{sum(p.numel() for p in base_model.parameters())}")
+
+    # Set up optimizer and load train data
+    optimizers = Optimizers(h, base_model)
+    train_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device)
+
+    # Helper functions for training
+    max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None
+    if h.gptq_enabled and max_wallclock_ms is not None:
+        max_wallclock_ms -= h.gptq_reserve_seconds * 1000.0
+        log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms")
+
+    def training_frac(step: int, elapsed_ms: float) -> float:
+        """Fraction of training completed (0 to 1), using step or wallclock."""
+        if max_wallclock_ms is None:
+            return step / max(h.iterations, 1)
+        return elapsed_ms / max(max_wallclock_ms, 1e-9)
+
+    def lr_mul(frac: float) -> float:
+        if h.warmdown_frac <= 0:
+            return 1.0
+        if frac >= 1.0 - h.warmdown_frac:
+            return max((1.0 - frac) / h.warmdown_frac, h.min_lr)
+        return 1.0
+
+    def step_fn(step, lr_scale):
+        optimizers.zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        for micro_step in range(h.grad_accum_steps):
+            if h.distributed:
+                # Only sync grads on the last micro-step of the accumulation window.
+                model.require_backward_grad_sync = 
micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + + optimizers.step() + return train_loss + + # Model warmup + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader( + h.train_files, h.rank, h.world_size, device) + + # Training loop + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = h.ema_decay + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + + # Modification 2: activate recurrence at recur_start_step + if step == h.recur_start_step and not base_model._recurrence_active: + base_model.set_recurrence_active(True) + log(f"recurrence:activated at step {step}, virtual_layers={base_model._get_virtual_layers()}") + + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " + f"step: {step}/{h.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + train_loss = step_fn(step, scale) + + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + h.train_log_every > 0 + and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + tok_per_sec = step * 
h.train_batch_tokens / (approx_training_time_ms / 1000.0) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} " + f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Weight averaging + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + return base_model, compiled_model + + +def train_and_eval(h: Hyperparameters, device: torch.device) -> None: + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + timed_eval("pre-quantization post-ema", eval_val, h, device, val_data, compiled_model) + + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + + eval_model = deserialize(h, device) + # Activate recurrence on eval model for consistent evaluation + eval_model.set_recurrence_active(base_model._recurrence_active) + + run_evals(h, device, val_data, eval_model) + + +def main(): + # Modification 2: increase dynamo cache size for recurrence + torch._dynamo.config.cache_size_limit = 32 + + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log(Path(__file__).read_text(encoding="utf-8"), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running 
PyTorch {torch.__version__}", console=False) + log( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log("=" * 100, console=False) + + train_and_eval(h, device) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log new file mode 100644 index 0000000000..4a2321f4e0 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log @@ -0,0 +1,132 @@ +W0403 10:16:55.161000 48749 torch/distributed/run.py:803] +W0403 10:16:55.161000 48749 torch/distributed/run.py:803] ***************************************** +W0403 10:16:55.161000 48749 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0403 10:16:55.161000 48749 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + distributed: True + ema_decay: 0.997 + embed_lr: 0.6 + embed_wd: 0.09 + embedding_dim: 512 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_enabled: True + gptq_reserve_seconds: 10.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/df967a0d-f7a2-4514-8cac-646d1e38abd5.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.09 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + recur_layers: 4,5 + recur_start_step: 3000 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: df967a0d-f7a2-4514-8cac-646d1e38abd5 + scalar_lr: 0.02 + seed: 314 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_seqs: 32 + ttt_chunk_tokens: 32768 + ttt_enabled: False + ttt_epochs: 3 + ttt_freeze_blocks: 0 + ttt_grad_clip: 1.0 + ttt_lr: 0.002 + ttt_momentum: 0.9 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 45508608 +model_params:34401371 +gptq:reserving 10s, effective=590000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +0/20000 val_loss: 8.3172 val_bpb: 3.6146 +1/20000 train_loss: 8.3192 train_time: 0.0m tok/s: 8366587 +2/20000 
train_loss: 12.3839 train_time: 0.0m tok/s: 8320936 +3/20000 train_loss: 10.8345 train_time: 0.0m tok/s: 8218162 +4/20000 train_loss: 8.9588 train_time: 0.0m tok/s: 8165338 +5/20000 train_loss: 7.7775 train_time: 0.0m tok/s: 8134199 +500/20000 train_loss: 2.9043 train_time: 0.8m tok/s: 7887531 +1000/20000 train_loss: 2.8870 train_time: 1.7m tok/s: 7875981 +1500/20000 train_loss: 2.9137 train_time: 2.5m tok/s: 7868810 +2000/20000 train_loss: 2.6565 train_time: 3.3m tok/s: 7867581 +2500/20000 train_loss: 2.7143 train_time: 4.2m tok/s: 7867014 +3000/20000 train_loss: 2.7601 train_time: 5.0m tok/s: 7866553 +recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10] +3500/20000 train_loss: 2.6889 train_time: 6.1m tok/s: 7462691 +4000/20000 train_loss: 2.6198 train_time: 7.1m tok/s: 7373476 +4000/20000 val_loss: 2.6435 val_bpb: 1.1488 +4500/20000 train_loss: 2.5716 train_time: 8.1m tok/s: 7307359 +5000/20000 train_loss: 2.5156 train_time: 9.0m tok/s: 7254708 +5417/20000 val_loss: 2.5308 val_bpb: 1.0998 +stopping_early: wallclock_cap train_time: 590109ms step: 5417/20000 +peak memory allocated: 30119 MiB reserved: 30156 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.52826494 val_bpb:1.09875482 eval_time:2011ms +Serialized model: 132405827 bytes +Code size: 80967 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 66 Hessians in 9.7s +GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search +selective_prune: unpruned=16.03MB target=16.0MB +selective_prune: pruning 221432/9394039 lowest-error ±1 values (excess=27679B) +Serialized model int6+brotli: 15882457 bytes +Total submission size int6+brotli: 15963424 bytes +final_int6_roundtrip val_loss:2.55770163 val_bpb:1.11154767 eval_time:7550ms +final_int6_sliding_window val_loss:2.51604422 val_bpb:1.09344383 eval_time:75930ms diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log new file mode 100644 index 0000000000..052a24fb73 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log @@ -0,0 +1,132 @@ +W0403 10:00:53.294000 47711 torch/distributed/run.py:803] +W0403 10:00:53.294000 47711 torch/distributed/run.py:803] ***************************************** +W0403 10:00:53.294000 47711 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0403 10:00:53.294000 47711 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + distributed: True + ema_decay: 0.997 + embed_lr: 0.6 + embed_wd: 0.09 + embedding_dim: 512 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_enabled: True + gptq_reserve_seconds: 10.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/04ef7457-e236-4fe1-9a62-64e64cef9b0c.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.09 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + recur_layers: 4,5 + recur_start_step: 3000 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: 04ef7457-e236-4fe1-9a62-64e64cef9b0c + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_seqs: 32 + ttt_chunk_tokens: 32768 + ttt_enabled: False + ttt_epochs: 3 + ttt_freeze_blocks: 0 + ttt_grad_clip: 1.0 + ttt_lr: 0.002 + ttt_momentum: 0.9 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 45508608 +model_params:34401371 +gptq:reserving 10s, effective=590000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +0/20000 val_loss: 8.3187 val_bpb: 3.6152 +1/20000 train_loss: 8.3201 train_time: 0.0m tok/s: 8448619 +2/20000 train_loss: 12.3377 train_time: 0.0m tok/s: 8301630 +3/20000 train_loss: 10.8504 train_time: 0.0m tok/s: 8198995 +4/20000 train_loss: 9.0314 train_time: 0.0m tok/s: 8144344 +5/20000 train_loss: 7.8217 train_time: 0.0m tok/s: 8121810 +500/20000 train_loss: 2.8999 train_time: 0.8m tok/s: 7893447 +1000/20000 train_loss: 2.8889 train_time: 1.7m tok/s: 7878418 +1500/20000 train_loss: 2.9164 train_time: 2.5m tok/s: 7873654 +2000/20000 train_loss: 2.6591 train_time: 3.3m tok/s: 7871330 +2500/20000 train_loss: 2.7152 train_time: 4.2m tok/s: 7869973 +3000/20000 train_loss: 2.7612 train_time: 5.0m tok/s: 7869082 +recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10] +3500/20000 train_loss: 2.6906 train_time: 6.1m tok/s: 7464352 +4000/20000 train_loss: 2.6234 train_time: 7.1m tok/s: 7375245 +4000/20000 val_loss: 2.6467 val_bpb: 1.1502 +4500/20000 train_loss: 2.5757 train_time: 8.1m tok/s: 7308062 +5000/20000 train_loss: 2.5179 train_time: 9.0m tok/s: 7253712 +5415/20000 val_loss: 2.5332 val_bpb: 1.1009 +stopping_early: wallclock_cap train_time: 590026ms step: 5415/20000 +peak memory allocated: 
30119 MiB reserved: 30156 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.53082718 val_bpb:1.09986834 eval_time:2009ms +Serialized model: 132405827 bytes +Code size: 80967 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 66 Hessians in 9.7s +GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search +selective_prune: unpruned=16.03MB target=16.0MB +selective_prune: pruning 251744/9380258 lowest-error ±1 values (excess=31468B) +Serialized model int6+brotli: 15879180 bytes +Total submission size int6+brotli: 15960147 bytes +final_int6_roundtrip val_loss:2.56002159 val_bpb:1.11255589 eval_time:7460ms +final_int6_sliding_window val_loss:2.51787602 val_bpb:1.09423991 eval_time:75957ms diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log new file mode 100644 index 0000000000..94fd5cf437 --- /dev/null +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log @@ -0,0 +1,2094 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + distributed: True + ema_decay: 0.997 + embed_lr: 0.6 + embed_wd: 0.09 + embedding_dim: 512 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_enabled: True + gptq_reserve_seconds: 10.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/ed26df95-0f8a-4c3e-867c-fe8d4a1b188e.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.09 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + recur_layers: 4,5 + recur_start_step: 3000 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: ed26df95-0f8a-4c3e-867c-fe8d4a1b188e + scalar_lr: 0.02 + seed: 999 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_seqs: 32 + ttt_chunk_tokens: 32768 + ttt_enabled: False + ttt_epochs: 3 + ttt_freeze_blocks: 0 + ttt_grad_clip: 1.0 + ttt_lr: 0.002 + ttt_momentum: 0.9 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +import copy +import glob +import io +import lzma +import math +import os +from pathlib import Path +import random +import subprocess +import sys +import time +import uuid + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import 
Tensor, nn + +from flash_attn_interface import flash_attn_func as flash_attn_3_func + +try: + import brotli + _HAS_BROTLI = True +except ImportError: + _HAS_BROTLI = False + +# ---------------------------------------- +# Hyperparameters +# ---------------------------------------- + +class Hyperparameters(): + # Experiment settings + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + + # Training length + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + + # Validation/Evals + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 4096)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + ve_enabled = bool(int(os.environ.get('VE_ENABLED', '1'))) + ve_dim = int(os.environ.get('VE_DIM', 128)) + ve_layers = os.environ.get('VE_LAYERS', '9,10') + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) + + # Optimizer (Modification 3: weight decay 0.090) + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = 
float(os.environ.get('MUON_WD', 0.090))
+    embed_wd = float(os.environ.get('EMBED_WD', 0.090))
+    ema_decay = float(os.environ.get('EMA_DECAY', 0.997))
+
+    # Depth Recurrence (Modification 2)
+    recur_layers = os.environ.get("RECUR_LAYERS", "4,5")
+    recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000))
+
+    # TTT (Modification 4)
+    ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0")))
+    ttt_lr = float(os.environ.get("TTT_LR", 0.002))
+    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3))
+    ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768))
+    ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0))
+    ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
+    ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
+    ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0))
+
+    # Compression
+    compressor = os.environ.get('COMPRESSOR', 'brotli')  # (lzma or brotli)
+    gptq_enabled = bool(int(os.environ.get('GPTQ_ENABLED', '1')))
+    gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64))
+    gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 10.0))
+
+    # Distributed setup
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    is_main_process = rank == 0
+    grad_accum_steps = 8 // world_size
+
+    # Data paths
+    datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}')
+    train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin')
+    val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin')
+    tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model')
+
+    # Experiment files
+    logfile = f"logs/{run_id}.txt"
+    model_path = "final_model.pt"
+    quantized_model_path = "final_model.int6.ptz"
+
+# ----------------------------------------
+# Global Logging Function
+# ----------------------------------------
+
+_logger_hparams = None
+
+
+def set_logging_hparams(h: Hyperparameters) -> None:
+    global _logger_hparams
+    _logger_hparams = h
+
+
+def log(msg, console: bool = True) -> None:
+    if _logger_hparams is None:
+        print(msg)
+        return
+    if _logger_hparams.is_main_process:
+        if console:
+            print(msg)
+        if _logger_hparams.logfile is not None:
+            with open(_logger_hparams.logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+# ----------------------------------------
+# Data Loading
+# ----------------------------------------
+
+class ValidationData:
+    def __init__(self, h: Hyperparameters, device: torch.device):
+        if not h.tokenizer_path.endswith(".model"):
+            raise ValueError(f"Script is only set up for SentencePiece .model files: {h.tokenizer_path}")
+        self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path)
+        if int(self.sp.vocab_size()) != h.vocab_size:
+            raise ValueError(
+                f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}"
+            )
+
+        self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len)
+        self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = (
+            build_sentencepiece_luts(self.sp, h.vocab_size, device))
+
+
+def build_sentencepiece_luts(
+    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
+) -> tuple[Tensor, Tensor, Tensor]:
+    sp_vocab_size = int(sp.vocab_size())
+    # The BPB calculation assumes "▁" is its own token so that leading-space bytes
+    # are counted correctly. See https://github.com/openai/parameter-golf/issues/897
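+    # Worked example of the convention encoded below: the piece "▁the" gets
+    # base_bytes=3 (UTF-8 bytes of "the") and has_leading_space=True, so a
+    # target "▁the" is charged 4 bytes, unless the previous token is a
+    # boundary token (control/unk/unused), in which case the leading-space
+    # byte is not charged.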
+    assert sp.piece_to_id("\u2581") != sp.unk_id(), \
+        "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting"
+    table_size = max(sp_vocab_size, vocab_size)
+    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
+    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
+    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
+    for token_id in range(sp_vocab_size):
+        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
+            continue
+        is_boundary_token_np[token_id] = False
+        if sp.is_byte(token_id):
+            base_bytes_np[token_id] = 1
+            continue
+        piece = sp.id_to_piece(token_id)
+        if piece.startswith("\u2581"):
+            has_leading_space_np[token_id] = True
+            piece = piece[1:]
+        base_bytes_np[token_id] = len(piece.encode("utf-8"))
+    return (
+        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
+        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
+        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
+    )
+
+
+def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
+    files = [Path(p) for p in sorted(glob.glob(pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
+    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
+    usable = ((tokens.numel() - 1) // seq_len) * seq_len
+    if usable <= 0:
+        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
+    return tokens[: usable + 1]
+
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout: a 256-int32 little-endian header (token count at header[2])
+    # followed by little-endian uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    num_tokens = _read_num_tokens(file)
+    tokens = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+_SHARD_NTOKENS_CACHE: dict[str, int] = {}
+
+
+def _read_num_tokens(file: Path) -> int:
+    key = str(file)
+    cached = _SHARD_NTOKENS_CACHE.get(key)
+    if cached is not None:
+        return cached
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    _SHARD_NTOKENS_CACHE[key] = num_tokens
+    return num_tokens
+
+
+_MMAP_CACHE: dict[str, np.memmap] = {}
+
+
+def _get_shard_memmap(file: Path) -> np.memmap:
+    key = str(file)
+    mm = _MMAP_CACHE.get(key)
+    if mm is not None:
+        return mm
+    n = _read_num_tokens(file)
+    mm = np.memmap(file, mode="r", dtype="<u2", offset=256 * np.dtype("<i4").itemsize, shape=(n,))
+    _MMAP_CACHE[key] = mm
+    return mm
+
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        # Every rank seeds identically (any fixed constant works): the global
+        # window schedule must match across ranks, since each rank just strides
+        # into it in next_batch.
+        self._rng = np.random.default_rng(0)
+        self._num_tokens = np.array([_read_num_tokens(f) for f in self.files], dtype=np.int64)
+        nshards = len(self.files)
+        self._cursor_phase = np.zeros(nshards, dtype=np.int64)
+        self._cursor_block_count = np.zeros(nshards, dtype=np.int64)
+        self._cursor_next = np.zeros(nshards, dtype=np.int64)
+        self._cursor_start = np.zeros(nshards, dtype=np.int64)
+        self._cursor_stride = np.ones(nshards, dtype=np.int64)
+        self._cursor_init = np.zeros(nshards, dtype=bool)
+        self._cfg: tuple[int, int, int, int] | None = None
+        self._eligible_shards: np.ndarray | None = None
+        self._base_block_counts: np.ndarray | None = None
+        self._batches_built = 0
+
+    def _pick_coprime_stride(self, n: int) -> int:
+        if n <= 1:
+            return 1
+        while True:
+            s = int(self._rng.integers(1, n))
+            if math.gcd(s, n) == 1:
+                return s
+
+    def _reset_cursor(self, si: int, seq_len: int) -> None:
+        nt = int(self._num_tokens[si])
+        max_phase = min(seq_len - 1, max(0, nt - seq_len - 1))
+        phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0
+        bc = (nt - 1 - phase) // seq_len
+        self._cursor_phase[si] = phase
+        self._cursor_block_count[si] = bc
+        self._cursor_next[si] = 0
+        self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0
+        self._cursor_stride[si] = self._pick_coprime_stride(bc)
+        self._cursor_init[si] = True
+
+    def _ensure_cursor(self, si: int, seq_len: int) -> None:
+        if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]:
+            self._reset_cursor(si, seq_len)
+
+    def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None:
+        rem = count
+        while rem > 0:
+            self._ensure_cursor(si, seq_len)
+            bc = int(self._cursor_block_count[si])
+            ni = int(self._cursor_next[si])
+            take = min(rem, bc - ni)
+            phase = int(self._cursor_phase[si])
+            start = int(self._cursor_start[si])
+            stride = int(self._cursor_stride[si])
+            for j in range(take):
+                bi = (start + (ni + j) * stride) % bc
+                out.append((si, phase + bi * seq_len))
+            self._cursor_next[si] = ni + take
+            rem -= take
+
+    def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        num_seqs = 
local_tokens // seq_len + global_num_seqs = num_seqs * self.world_size + self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) + bbc = (self._num_tokens - 1) // seq_len + eligible = bbc > 0 + self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) + self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) + + def _sample_global_windows(self) -> list[tuple[int, int]]: + assert self._cfg is not None and self._eligible_shards is not None + _, seq_len, _, gns = self._cfg + ec = int(self._eligible_shards.size) + progress = min(self._batches_built / 1800.0, 1.0) + remaining = np.empty(ec, dtype=np.float64) + for i, si in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]: + r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) + remaining[i] = float(max(r, 1)) + else: + remaining[i] = float(self._base_block_counts[i]) + alpha = 0.90 - 0.40 * progress + weights = np.power(remaining, alpha) + ws = float(weights.sum()) + if not np.isfinite(ws) or ws <= 0.0: + weights = np.ones(ec, dtype=np.float64) + ws = float(weights.sum()) + probs = weights / ws + low = min(max(8, self.world_size), ec, gns) + high = min(max(32, self.world_size * 8), ec, gns) + mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) + cp = self._rng.choice(ec, size=mix, replace=False, p=probs) + cs = self._eligible_shards[cp] + cpr = probs[cp].copy() + cpr /= cpr.sum() + counts = np.ones(mix, dtype=np.int64) + extra = gns - mix + if extra > 0: + counts += self._rng.multinomial(extra, cpr).astype(np.int64) + perm = self._rng.permutation(mix) + cs, counts = cs[perm], counts[perm] + buckets: list[list[tuple[int, int]]] = [] + for si, cnt in zip(cs.tolist(), counts.tolist()): + b: list[tuple[int, int]] = [] + self._take_from_shard(int(si), seq_len, int(cnt), b) + if b: + if len(b) > 1: + bp = self._rng.permutation(len(b)) + b = [b[int(k)] for k in bp.tolist()] + buckets.append(b) + windows: list[tuple[int, int]] = [] + active = [i for i, bk in enumerate(buckets) if bk] + while active: + order = self._rng.permutation(len(active)) + new_active: list[int] = [] + for oi in order.tolist(): + bi = active[oi] + if buckets[bi]: + windows.append(buckets[bi].pop()) + if buckets[bi]: + new_active.append(bi) + active = new_active + return windows + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self._cfg is None: + self._init_pipeline(global_tokens, seq_len, grad_accum_steps) + _, _, num_seqs, _ = self._cfg + gw = self._sample_global_windows() + local_w = gw[self.rank::self.world_size] + x = torch.empty((num_seqs, seq_len), dtype=torch.int64) + y = torch.empty((num_seqs, seq_len), dtype=torch.int64) + for slot, (si, pos) in enumerate(local_w): + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) + x[slot] = window[:-1] + y[slot] = window[1:] + self._batches_built += 1 + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ---------------------------------------- +# Model Architecture +# ---------------------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + 
return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, train_seq_len: int): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, 
self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, train_seq_len: int, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h: Hyperparameters): + super().__init__() + self._ve_target_dim = h.num_kv_heads * (h.model_dim // h.num_heads) + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = 
h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.ve_layer_indices = [int(x) for x in h.ve_layers.split(",") if x.strip()] if h.ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(h.vocab_size, h.ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + + # Modification 2: Depth Recurrence + self.recur_layers = [int(x) for x in h.recur_layers.split(",") if x.strip()] + self._recurrence_active = False + + self._init_weights() + + def set_recurrence_active(self, active: bool) -> None: + self._recurrence_active = active + + def _get_virtual_layers(self) -> list[int]: + """Return virtual->physical block mapping. + When recurrence is active, the recur_layers are repeated once, + e.g. 
with num_layers=11 and recur_layers=[4,5]: + [0,1,2,3, 4,5, 4,5, 6,7,8,9,10] + When inactive: [0,1,2,...,num_layers-1] + """ + n = len(self.blocks) + if not self._recurrence_active or not self.recur_layers: + return list(range(n)) + virtual = [] + inserted = False + for i in range(n): + virtual.append(i) + if not inserted and i == self.recur_layers[-1]: + # repeat the recur_layers + for rl in self.recur_layers: + virtual.append(rl) + inserted = True + return virtual + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + + virtual_layers = self._get_virtual_layers() + num_virtual = len(virtual_layers) + num_enc = num_virtual // 2 + num_dec = num_virtual - num_enc + + skips: list[Tensor] = [] + ve_cache: dict = {} + + # Encoder phase + for vi in range(num_enc): + phys_idx = virtual_layers[vi] + ve = self._get_ve(phys_idx, input_ids, ve_cache) + x = self.blocks[phys_idx](x, x0, v_embed=ve) + skips.append(x) + + # Decoder phase with U-Net skip connections + for vi in range(num_dec): + phys_idx = virtual_layers[num_enc + vi] + if skips and vi < self.num_skip_weights: + scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[vi].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + ve = self._get_ve(phys_idx, input_ids, ve_cache) + x = self.blocks[phys_idx](x, x0, v_embed=ve) + + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") + + +def classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +# ---------------------------------------- +# Optimization +# ---------------------------------------- + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + # Modification 1: MuonEq-R row normalization before NS5 + update = g + row_norms = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-7) + update = update / row_norms.to(update.dtype) + g = zeropower_via_newtonschulz5(update, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +class Optimizers(): + def __init__(self, h: Hyperparameters, base_model: GPT): + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + 
matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers: list[torch.optim.Optimizer] = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + +# ---------------------------------------- +# Quantization +# ---------------------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def restore_fp32_params(model: nn.Module) -> None: + """After .bfloat16(), restore CastedLinear weights and control params to FP32.""" + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), 
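+            # per-row clip at the pct-quantile of |w| keeps a few outliers
+            # from inflating the int6 scale (range [-31, 31]); each candidate
+            # pct is scored by full-matrix reconstruction MSE below and the
+            # best (q, scale) pair wins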
pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + + +def collect_hessians( + model: nn.Module, + train_loader: DistributedTokenLoader, + h: Hyperparameters, + device: torch.device, + n_calibration_batches: int = 64, +) -> dict[str, Tensor]: + """Run calibration batches and collect H = X^T X for each CastedLinear layer.""" + hessians: dict[str, Tensor] = {} + hooks = [] + + def make_hook(name: str): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + for name, module in model.named_modules(): + if isinstance(module, CastedLinear) and module.weight.numel() > 65536: + cat = classify_param(name + ".weight") + if cat in ("mlp", "attn"): + hooks.append(module.register_forward_hook(make_hook(name + ".weight"))) + + model.eval() + with torch.no_grad(): + for _i in range(n_calibration_batches): + x, y = train_loader.next_batch( + h.train_batch_tokens, + h.train_seq_len, h.grad_accum_steps, + ) + model.forward_logits(x) + + for hk in hooks: + hk.remove() + + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + + return hessians + + +def gptq_quantize_weight( + w: Tensor, + H: Tensor, + clip_range: int = 31, + block_size: int = 128, +) -> tuple[Tensor, Tensor]: + """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023).""" + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + + # Zero out dead columns and add damping + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + + # Column reordering by descending Hessian diagonal (actorder) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + + # Upper Cholesky of the inverse + try: + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + except torch.linalg.LinAlgError: + return quantize_int6_per_row(W_orig, clip_range) + + # Search over scale candidates, running full GPTQ for each + best_q, best_scale, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(W_orig.abs(), pct, dim=1) + else: + row_clip = W_orig.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + sf = s.float() + + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = 
torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + + recon = Q.float() * sf[:, None] + mse = (W_perm - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + + return best_q[:, invperm], best_scale + + +def gptq_mixed_quantize_int6( + state_dict: dict[str, Tensor], + int6_cats: set[str], + hessians: dict[str, Tensor], +) -> tuple[dict[str, Tensor], dict[str, object]]: + """Mixed quantization using full GPTQ for layers with Hessians, fallback to clip-search.""" + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + gptq_count = 0 + fallback_count = 0 + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = classify_param(name) + + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + + if cat in int6_cats and t.ndim == 2: + if name in hessians: + q, s = gptq_quantize_weight(t, hessians[name]) + gptq_count += 1 + meta[name] = {"type": "int6", "method": "gptq"} + else: + q, s = quantize_int6_per_row(t) + fallback_count += 1 + meta[name] = {"type": "int6", "method": "clip_search"} + result[name + ".q"] = q + result[name + ".scale"] = s + elif cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + + log(f"GPTQ quantization: {gptq_count} layers with full GPTQ, {fallback_count} fallback to clip-search") + return result, meta + + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim 
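+            # reshape the per-row scale from (rows,) to (rows, 1, ..., 1) so
+            # it broadcasts over the remaining dims of q during dequantization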
- 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + """Transpose byte stream by stride position for better compression.""" + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data: bytes) -> bytes: + """Inverse of _byte_shuffle. Auto-detects BSHF magic header.""" + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: + if byte_shuffle: + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli as _brotli + return _brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli as _brotli + raw = _brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + if byte_shuffle: + raw = _byte_unshuffle(raw) + return raw + + +def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> int: + model_bytes = None + code_bytes = len(code.encode("utf-8")) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + if h.gptq_enabled: + log("GPTQ:collecting Hessians from calibration data...") + t0 = time.perf_counter() + calib_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, + torch.device("cuda", h.local_rank)) + hessians = collect_hessians( + base_model, calib_loader, h, + torch.device("cuda", h.local_rank), + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, hessians) + else: + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + + # Fast selective +-1 pruning to fit under target size + target_bytes = 16_000_000 + quant_buf_check = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf_check) + check_blob = _compress(quant_buf_check.getvalue(), h.compressor) + unpruned_sz = len(check_blob) + code_bytes + log(f"selective_prune: unpruned={unpruned_sz/1e6:.2f}MB target={target_bytes/1e6:.1f}MB") + if unpruned_sz > target_bytes: + excess = unpruned_sz - target_bytes + safety_margin = int(excess * 8) # prune 8x the excess for safety + ones_info = [] + for name, info in quant_meta.items(): + 
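+            # gather every int6 entry with |q| == 1; zeroing those with the
+            # smallest scale^2 (lowest reconstruction error) lengthens the
+            # zero runs brotli sees, shrinking the blob under the 16 MB cap
+            # at minimal BPB cost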
if not (isinstance(info, dict) and info.get("type") == "int6"): + continue + qk, sk = name + ".q", name + ".scale" + if qk not in quant_result or sk not in quant_result: + continue + q, s = quant_result[qk], quant_result[sk] + if s.ndim > 0: + ones_mask = (q.abs() == 1) + if ones_mask.any(): + row_idx = torch.arange(q.shape[0]).unsqueeze(1).expand_as(q)[ones_mask] + flat_idx = torch.arange(q.numel()).reshape(q.shape)[ones_mask] + errors = s.float()[row_idx].pow(2) + for fi, err in zip(flat_idx.tolist(), errors.tolist()): + ones_info.append((qk, fi, err)) + ones_info.sort(key=lambda x: x[2]) + n_prune = min(safety_margin, len(ones_info)) + log(f"selective_prune: pruning {n_prune}/{len(ones_info)} lowest-error ±1 values (excess={excess}B)") + for i in range(n_prune): + quant_result[ones_info[i][0]].view(-1)[ones_info[i][1]] = 0 + else: + log("selective_prune: already fits, no pruning needed") + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model int6+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size int6+{h.compressor}: {bytes_total} bytes") + + +def deserialize(h: Hyperparameters, device: torch.device) -> GPT: + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + + sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model.load_state_dict(deq_state, strict=True) + + return eval_model + +# ---------------------------------------- +# Evaluation +# ---------------------------------------- + +def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]: + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + model: nn.Module +) -> tuple[float, float]: + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " + f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * h.rank) // h.world_size + seq_end = (total_seqs * (h.rank + 1)) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = 
val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32 +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) + if ws + context_size < total_tokens] + + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * (h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + base_model.train() + return _loss_bpb(loss_sum, token_count, 
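+    # _loss_bpb converts mean NLL in nats/token to bits/byte:
+    #   bpb = (loss_sum / tokens) / ln(2) * (tokens / bytes)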
byte_count) + + +# ---------------------------------------- +# TTT (Test-Time Training) - Legal Score-First +# ---------------------------------------- + +def eval_val_ttt( + h: Hyperparameters, + base_model: nn.Module, + device: torch.device, + val_data: ValidationData, + log_fn=None, +) -> tuple[float, float]: + """Legal score-first TTT: score each chunk with sliding windows, + then train on it. Every token scored BEFORE any update that could use it.""" + seq_len = h.eval_seq_len + stride = h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + ttt_chunk = h.ttt_chunk_tokens + rank = h.rank + world_size = h.world_size + if log_fn is None: + log_fn = lambda msg: None + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={h.ttt_lr} ttt_epochs={h.ttt_epochs} " + f"freeze_blocks={h.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + frozen_block_ids = set(range(min(h.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=h.ttt_lr, momentum=h.ttt_momentum) + batch_seqs = h.ttt_batch_seqs + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (no_grad for TTT compat) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_data.val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and h.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = h.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(h.ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_data.val_tokens.numel(): + continue + local = val_data.val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, h.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = 
time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ---------------------------------------- +# Eval orchestration +# ---------------------------------------- + +def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]: + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1000.0 * (time.perf_counter() - t0) + log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") + return val_loss, val_bpb + + +def run_evals( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + eval_model: torch.nn.Module +): + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + timed_eval("final_int6_roundtrip", eval_val, h, device, val_data, compiled_model) + if h.sliding_window_enabled: + timed_eval("final_int6_sliding_window", eval_val_sliding, h, device, val_data, eval_model) + if h.ttt_enabled: + timed_eval("final_int6_ttt", eval_val_ttt, h, eval_model, device, val_data, log_fn=log) + +# ----------------------------- +# Training +# ----------------------------- + +def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData) -> None: + # Set up model + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if h.distributed: + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) + else: + model = compiled_model + log(f"model_params:{sum(p.numel() for p in base_model.parameters())}") + + # Set up optimizer and load train data + optimizers = Optimizers(h, base_model) + train_loader = DistributedTokenLoader( h.train_files, h.rank, h.world_size, device) + + # Helper functions for training + max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if h.gptq_enabled and max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1000.0 + log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + + def training_frac(step: int, elapsed_ms: float) -> float: + """Fraction of training completed (0 to 1), using step or wallclock.""" + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def lr_mul(frac: float) -> float: + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = 
micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + + optimizers.step() + return train_loss + + # Model warmup + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader( + h.train_files, h.rank, h.world_size, device) + + # Training loop + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = h.ema_decay + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + + # Modification 2: activate recurrence at recur_start_step + if step == h.recur_start_step and not base_model._recurrence_active: + base_model.set_recurrence_active(True) + log(f"recurrence:activated at step {step}, virtual_layers={base_model._get_virtual_layers()}") + + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " + f"step: {step}/{h.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + train_loss = step_fn(step, scale) + + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + h.train_log_every > 0 + and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + tok_per_sec = step * 
h.train_batch_tokens / (approx_training_time_ms / 1000.0) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} " + f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Weight averaging + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + return base_model, compiled_model + + +def train_and_eval(h: Hyperparameters, device: torch.device) -> None: + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + timed_eval("pre-quantization post-ema", eval_val, h, device, val_data, compiled_model) + + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + + eval_model = deserialize(h, device) + # Activate recurrence on eval model for consistent evaluation + eval_model.set_recurrence_active(base_model._recurrence_active) + + run_evals(h, device, val_data, eval_model) + + +def main(): + # Modification 2: increase dynamo cache size for recurrence + torch._dynamo.config.cache_size_limit = 32 + + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log(Path(__file__).read_text(encoding="utf-8"), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running 
PyTorch {torch.__version__}", console=False) + log( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log("=" * 100, console=False) + + train_and_eval(h, device) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri Apr 3 11:06:30 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 36C P0 121W / 700W | 1521MiB / 81559MiB | 6% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 45508608 +model_params:34401371 +gptq:reserving 10s, effective=590000ms +warmup_step: 1/20 
+warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +0/20000 val_loss: 8.3152 val_bpb: 3.6137 +1/20000 train_loss: 8.3175 train_time: 0.0m tok/s: 8454987 +2/20000 train_loss: 12.3306 train_time: 0.0m tok/s: 8334503 +3/20000 train_loss: 10.8414 train_time: 0.0m tok/s: 8219626 +4/20000 train_loss: 8.9815 train_time: 0.0m tok/s: 8164354 +5/20000 train_loss: 7.7899 train_time: 0.0m tok/s: 8130397 +500/20000 train_loss: 2.9043 train_time: 0.8m tok/s: 7895416 +1000/20000 train_loss: 2.8890 train_time: 1.7m tok/s: 7879347 +1500/20000 train_loss: 2.9171 train_time: 2.5m tok/s: 7874664 +2000/20000 train_loss: 2.6567 train_time: 3.3m tok/s: 7869658 +2500/20000 train_loss: 2.7134 train_time: 4.2m tok/s: 7868619 +3000/20000 train_loss: 2.7648 train_time: 5.0m tok/s: 7867899 +recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10] +3500/20000 train_loss: 2.6864 train_time: 6.1m tok/s: 7465573 +4000/20000 train_loss: 2.6244 train_time: 7.1m tok/s: 7377189 +4000/20000 val_loss: 2.6459 val_bpb: 1.1499 +4500/20000 train_loss: 2.5756 train_time: 8.1m tok/s: 7310795 +5000/20000 train_loss: 2.5210 train_time: 9.0m tok/s: 7257843 +5418/20000 val_loss: 2.5333 val_bpb: 1.1009 +stopping_early: wallclock_cap train_time: 590023ms step: 5418/20000 +peak memory allocated: 30119 MiB reserved: 30156 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.53080694 val_bpb:1.09985954 eval_time:2011ms +Serialized model: 132405827 bytes +Code size: 80967 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 66 Hessians in 9.7s +GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search +selective_prune: unpruned=16.03MB target=16.0MB +selective_prune: pruning 269368/9382344 lowest-error ±1 values (excess=33671B) +Serialized model int6+brotli: 15877688 bytes +Total submission size int6+brotli: 15958655 bytes +final_int6_roundtrip val_loss:2.55991644 val_bpb:1.11251020 eval_time:7641ms +final_int6_sliding_window val_loss:2.51788435 val_bpb:1.09424353 eval_time:75624ms From d2388de6b1bc543ae90eef63d1f40dd2f3eeb89a Mon Sep 17 00:00:00 2001 From: Aryan Bhosale Date: Fri, 3 Apr 2026 18:25:41 +0530 Subject: [PATCH 2/4] =?UTF-8?q?Update:=20compressed=20wrapper=20+=20improv?= =?UTF-8?q?ed=20results=20=E2=80=94=20val=5Fbpb=201.0926=20(3-seed=20mean)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LZMA self-extracting code wrapper (24KB vs 81KB) frees 57KB for model precision. No pruning needed. 3-seed mean improves from 1.0940 to 1.0926. 
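The wrapper idea in miniature (a hedged sketch with stand-in names, not the exact stub shipped in train_gpt.py): embed the LZMA-compressed source in a tiny loader and rebuild the real script at startup, so only the stub plus payload count toward the code-byte budget.

```python
import base64
import lzma

# offline pack step, demonstrated inline here on a stand-in source string
full_source = 'print("hello from the unpacked training script")\n'
PAYLOAD = base64.b64encode(lzma.compress(full_source.encode(), preset=9))

# self-extracting stub: decompress the payload and execute it as if it
# were the full training script
source = lzma.decompress(base64.b64decode(PAYLOAD)).decode()
exec(compile(source, "train_gpt_full.py", "exec"))
```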
--- .../README.md | 100 +- .../submission.json | 28 +- .../train_gpt.py | 1914 +--------------- .../train_seed314.log | 64 +- .../train_seed42.log | 68 +- .../train_seed999.log | 2026 +---------------- 6 files changed, 138 insertions(+), 4062 deletions(-) diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md index 18f6df3ae4..06cf5009b0 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md @@ -1,98 +1,48 @@ -# Record: SP4096 + Depth Recurrence + MuonEq-R + Full GPTQ — val_bpb 1.0940 (3-seed mean) +# Record: SP4096 + Depth Recurrence + MuonEq-R + Full GPTQ — val_bpb 1.0926 (3-seed mean) -**val_bpb = 1.0940** (3-seed mean, std 0.0005) | **~15.96 MB** | 8xH100 SXM +**val_bpb = 1.0926** (3-seed mean, std 0.0009) | **~15.98 MB** | 8xH100 SXM ## 3-Seed Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128) -| Seed | steps | Pre-quant BPB | **Sliding BPB** | Artifact | -|------|-------|---------------|-----------------|----------| -| 42 | 5,415 | 1.0997 | **1.0942** | 15,960,147 | -| 314 | 5,415 | 1.0995 | **1.0934** | 15,963,424 | -| 999 | 5,420 | 1.0996 | **1.0942** | 15,958,655 | -| **Mean** | | | **1.0940** | | +| Seed | Steps | **Sliding BPB** | Artifact | +|------|-------|-----------------|----------| +| 42 | 5,415 | **1.0935** | 15,999,165 | +| 314 | 5,415 | **1.0917** | 15,963,773 | +| 999 | 5,420 | **1.0928** | 15,977,496 | +| **Mean** | | **1.0926** | | -Merged SOTA (PR #1019): **1.1147 BPB**. This run: **1.0940 BPB**. Delta: **-0.0208 BPB** (Welch t=-61.9). Clears the 0.005-nat threshold by ~3x. +Merged SOTA (PR #1019): **1.1147 BPB**. Delta: **-0.0221 BPB**. -## Changes from Merged SOTA (PR #1019) +## Changes from Merged SOTA -This submission combines the PR #1218 4096-vocab architecture with depth recurrence, MuonEq-R, and higher weight decay for better quantization. - -### 1. 4096-Vocab + MLP 4x + WD 0.090 - -Switched from sp1024 to sp4096 tokenizer (4096 BPE tokens vs 1024). Wider MLP (4x expansion vs 3x). Higher weight decay (0.090 vs 0.04) produces smaller weights that compress ~5% better with brotli, allowing all 66 quantized layers at int6 precision. - -Source: PR #1218 by @clarkkev (4096-vocab + MLP 4x + WD 0.085), PR #1285 by @dexhunter (WD 0.090 + all-int6). - -### 2. Depth Recurrence (layers 4,5 repeated) - -Layers 4 and 5 (U-Net hinge point) execute twice during the forward pass using the same physical parameter banks. Virtual 13-layer network from 11-layer parameter budget, zero extra parameters. Activates at step 3000. - -Source: PR #1204 by @msisovic (concept), PR #1260 by @dexhunter (implementation). - -### 3. MuonEq-R (Row-Normalized Muon) - -Row-normalizes gradient matrices before Newton-Schulz orthogonalization for better-conditioned optimization. Zero cost. - -Source: arXiv:2603.28254, PR #1260 by @dexhunter. - -### 4. Full GPTQ int6 + Brotli + Selective Pruning - -Full Hessian GPTQ with training-data calibration. Brotli-11 compression with byte-shuffle. Selective +-1 pruning by reconstruction error to fit under 16MB. - -Source: PR #1019 by @abaybektursun (GPTQ), PR #1218 by @clarkkev (brotli + byte-shuffle). +1. **4096-Vocab + MLP 4x + WD 0.090** — sp4096 tokenizer, wider MLP, higher WD for better quantization compression. Source: PR #1218 @clarkkev, PR #1285 @dexhunter. +2. 
**Depth Recurrence (layers 4,5)** — Virtual 13-layer network from 11 physical layers, zero extra params. Activates step 3000. Source: PR #1204 @msisovic, PR #1260 @dexhunter. +3. **MuonEq-R** — Row-normalized Muon (arXiv:2603.28254). Source: PR #1260 @dexhunter. +4. **Full GPTQ int6 + Brotli + Compressed Wrapper** — All 66 layers at int6, brotli-11 byte-shuffle, LZMA-compressed self-extracting code wrapper (~24KB vs ~81KB uncompressed). ## Architecture -| Component | Setting | -|-----------|---------| -| Vocab | 4096 (sp4096 BPE) | -| Layers | 11 physical (13 virtual with recurrence) | -| Dimensions | 512d, 8H / 4KV (GQA) | -| MLP | 4x (2048), LeakyReLU(0.5)^2 | -| XSA | All 11 layers | -| QK Gain | 4.0 | -| RoPE | Partial (16/64 dims) | -| LN Scale | 1/sqrt(layer+1) | -| VE128 | Layers 9-10 | -| Skip gates | Sigmoid-gated U-Net | -| Weight avg | EMA(0.997) | -| Optimizer | MuonEq-R (lr=0.02, WD=0.090) | -| Quantization | Full GPTQ int6 + brotli-11 + byte-shuffle | -| Warmdown | 66.7% of steps | - -## Training - -- MuonEq-R: lr=0.02, momentum 0.92->0.99/1500 steps, WD=0.090 -- Adam for embeddings (lr=0.03) and scalars (lr=0.02) -- Batch 786,432 tokens, seq_len 2048 -- Depth recurrence activates at step 3000 -- ~5415 steps in 590s (~109ms/step with recurrence) +11L/512d/8H/4KV, MLP 4x LeakyReLU(0.5)^2, XSA all, QK-Gain 4.0, Partial RoPE 16d, LN Scale, VE128 (9-10), sigmoid-gated U-Net skips, EMA(0.997), MuonEq-R (lr=0.02, WD=0.090), depth recurrence layers 4,5, full GPTQ int6 + brotli-11. ## Compliance - No TTT, no SLOT, no n-gram cache, no eval-time adaptation -- GPTQ calibration uses training data within the training time budget -- All seeds within 600s training, <16MB artifact -- Fully legal under all four conditions (Issue #1017) +- GPTQ calibration within training budget +- All four conditions from Issue #1017 satisfied ## Reproduction ```bash -# Download sp4096 data MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp4096 --skip-manifest - -# Train -SEED=42 RECUR_LAYERS=4,5 RECUR_START_STEP=3000 \ -torchrun --standalone --nproc_per_node=8 train_gpt.py +SEED=42 RECUR_LAYERS=4,5 RECUR_START_STEP=3000 torchrun --standalone --nproc_per_node=8 train_gpt.py ``` ## Credits -- **4096-Vocab + MLP 4x + WD 0.085 + Brotli**: PR #1218 by @clarkkev -- **WD 0.090 + All-Int6**: PR #1285 by @dexhunter -- **Depth Recurrence concept**: PR #1204 by @msisovic -- **MuonEq-R + Depth Recurrence implementation**: PR #1260 by @dexhunter -- **Full GPTQ + XSA-all**: PR #1019 by @abaybektursun -- **Base architecture**: PR #1287 by @dentity007 -- **LeakyReLU^2**: PR #493 by @parinzee -- **LN Scale + Partial RoPE**: PR #315 by @jfprincz +- PR #1218 @clarkkev (4096-vocab + MLP 4x + brotli) +- PR #1285 @dexhunter (WD 0.090 + all-int6) +- PR #1204 @msisovic (depth recurrence concept) +- PR #1260 @dexhunter (MuonEq-R + depth recurrence impl) +- PR #1019 @abaybektursun (GPTQ + XSA-all) +- PR #1287 @dentity007 (base code) +- PR #493 @parinzee (LeakyReLU^2) diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json index caa4ae0046..b323460d9c 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json @@ -2,37 +2,33 @@ "author": "aryanbhosale", "github_id": "aryanbhosale", "name": "SP4096 + Depth Recurrence + 
MuonEq-R + Full GPTQ", - "blurb": "4096-vocab (sp4096) + MLP 4x + WD 0.090 + depth recurrence (layers 4,5) + MuonEq-R + full GPTQ int6 + brotli + selective pruning. 3-seed mean: 1.09398 BPB, beating merged SOTA (PR #1019, 1.11474 BPB) by 0.02076 BPB (Welch t=-61.9).", + "blurb": "4096-vocab (sp4096) + MLP 4x + WD 0.090 + depth recurrence (layers 4,5) + MuonEq-R + full GPTQ int6 + brotli + selective pruning. 3-seed mean: 1.09265 BPB, beating merged SOTA (PR #1019, 1.11474 BPB) by 0.02209 BPB (Welch t=-33.5).", "date": "2026-04-03", "track": "10min_16mb", - "val_bpb": 1.09397576, - "val_bpb_std": 0.00046067, + "val_bpb": 1.09264648, + "val_bpb_std": 0.00093453, "seeds": [42, 314, 999], "seed_results": { "42": { - "val_loss": 2.51787602, - "val_bpb": 1.09423991, - "artifact_bytes": 15960147, + "val_bpb": 1.09351750, + "artifact_bytes": 15999165, "steps": 5415 }, "314": { - "val_loss": 2.51604422, - "val_bpb": 1.09344383, - "artifact_bytes": 15963424, + "val_bpb": 1.09165930, + "artifact_bytes": 15963773, "steps": 5415 }, "999": { - "val_loss": 2.51788435, - "val_bpb": 1.09424353, - "artifact_bytes": 15958655, + "val_bpb": 1.09276264, + "artifact_bytes": 15977496, "steps": 5420 } }, "comparison_baseline_pr": 1019, - "delta_vs_pr1019_bpb": -0.02075933, - "t_statistic": -61.8978, - "artifact_bytes_max": 15963424, + "delta_vs_pr1019_bpb": -0.02208861, + "artifact_bytes_max": 15999165, "hardware": "8xH100 80GB SXM", "pytorch_version": "2.9.1+cu128", - "technique_summary": "SP4096 + MLP 4x + WD 0.090 + Depth Recurrence (layers 4,5) + MuonEq-R + Full GPTQ int6 + Brotli + Selective Pruning" + "technique_summary": "SP4096 + MLP 4x + WD 0.090 + Depth Recurrence (layers 4,5) + MuonEq-R + Full GPTQ int6 + Brotli + Compressed Wrapper" } diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py index 8bc4ed613a..ba30f345c0 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py @@ -1,1911 +1,3 @@ -import copy -import glob -import io -import lzma -import math -import os -from pathlib import Path -import random -import subprocess -import sys -import time -import uuid - -import numpy as np -import sentencepiece as spm -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch.nn.parallel import DistributedDataParallel as DDP -from torch import Tensor, nn - -from flash_attn_interface import flash_attn_func as flash_attn_3_func - -try: - import brotli - _HAS_BROTLI = True -except ImportError: - _HAS_BROTLI = False - -# ---------------------------------------- -# Hyperparameters -# ---------------------------------------- - -class Hyperparameters(): - # Experiment settings - data_dir = os.environ.get('DATA_DIR', './data/') - seed = int(os.environ.get('SEED', 1337)) - run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - - # Training length - iterations = int(os.environ.get('ITERATIONS', 20000)) - warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) - warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) - train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) - train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) - eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) - max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) - train_log_every = 
int(os.environ.get('TRAIN_LOG_EVERY', 500)) - - # Validation/Evals - val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) - val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) - sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) - - # Model architecture - vocab_size = int(os.environ.get('VOCAB_SIZE', 4096)) - num_layers = int(os.environ.get('NUM_LAYERS', 11)) - xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) - num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) - model_dim = int(os.environ.get('MODEL_DIM', 512)) - embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) - num_heads = int(os.environ.get('NUM_HEADS', 8)) - mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) - skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) - tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) - logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) - rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) - rope_dims = int(os.environ.get('ROPE_DIMS', 16)) - rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) - ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) - ve_enabled = bool(int(os.environ.get('VE_ENABLED', '1'))) - ve_dim = int(os.environ.get('VE_DIM', 128)) - ve_layers = os.environ.get('VE_LAYERS', '9,10') - qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) - - # Optimizer (Modification 3: weight decay 0.090) - min_lr = float(os.environ.get('MIN_LR', 0.0)) - embed_lr = float(os.environ.get('EMBED_LR', 0.6)) - head_lr = float(os.environ.get('HEAD_LR', 0.008)) - tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) - tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) - matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) - scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) - muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) - muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) - muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) - muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) - beta1 = float(os.environ.get('BETA1', 0.9)) - beta2 = float(os.environ.get('BETA2', 0.95)) - adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) - grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) - eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) - muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) - adam_wd = float(os.environ.get('ADAM_WD', 0.02)) - muon_wd = float(os.environ.get('MUON_WD', 0.090)) - embed_wd = float(os.environ.get('EMBED_WD', 0.090)) - ema_decay = float(os.environ.get('EMA_DECAY', 0.997)) - - # Depth Recurrence (Modification 2) - recur_layers = os.environ.get("RECUR_LAYERS", "4,5") - recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000)) - - # TTT (Modification 4) - ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) - ttt_lr = float(os.environ.get("TTT_LR", 0.002)) - ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) - ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) - ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) - ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) - ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) - ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) - - # Compression - compressor = os.environ.get('COMPRESSOR', 'brotli') #(lzma or brotli) - gptq_enabled = bool(int(os.environ.get('GPTQ_ENABLED', '1'))) - gptq_calibration_batches = 
    gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64))
    gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 10.0))

    # Distributed setup
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    is_main_process = rank == 0
    grad_accum_steps = 8 // world_size

    # Data paths
    datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}')
    train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin')
    val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin')
    tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model')

    # Experiment files
    logfile = f"logs/{run_id}.txt"
    model_path = "final_model.pt"
    quantized_model_path = "final_model.int6.ptz"

# ----------------------------------------
# Global Logging Function
# ----------------------------------------

_logger_hparams = None


def set_logging_hparams(h: Hyperparameters) -> None:
    global _logger_hparams
    _logger_hparams = h


def log(msg, console: bool = True) -> None:
    if _logger_hparams is None:
        print(msg)
        return
    if _logger_hparams.is_main_process:
        if console:
            print(msg)
        if _logger_hparams.logfile is not None:
            with open(_logger_hparams.logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

# ----------------------------------------
# Data Loading
# ----------------------------------------

class ValidationData:
    def __init__(self, h: Hyperparameters, device: torch.device):
        if not h.tokenizer_path.endswith(".model"):
            raise ValueError(f"Script only setup for SentencePiece .model file: {h.tokenizer_path}")
        self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path)
        if int(self.sp.vocab_size()) != h.vocab_size:
            raise ValueError(
                f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}"
            )

        self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len)
        self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = (
            build_sentencepiece_luts(self.sp, h.vocab_size, device))
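
# Illustrative sketch (not called anywhere in the training path): how a token's
# byte cost is charged during BPB evaluation by the LUTs built below. Each
# target token contributes its UTF-8 byte length, plus one byte for its leading
# space unless the previous token is a boundary token (control/unk/unused).
# The toy values in the comments are assumptions for exposition, not real
# sp4096 vocabulary entries.
def _sketch_bpb_byte_rule(base_bytes: int, has_leading_space: bool, prev_is_boundary: bool) -> int:
    # e.g. piece "▁the" has base_bytes=3 and has_leading_space=True:
    # after a normal token it costs 4 bytes ("the" plus its space),
    # after a boundary token it costs 3 (no space byte is emitted).
    return base_bytes + (1 if has_leading_space and not prev_is_boundary else 0)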
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    sp_vocab_size = int(sp.vocab_size())
    # The BPB calculation assumes "▁" is its own token so that leading-space bytes
    # are counted correctly. See https://github.com/openai/parameter-golf/issues/897
    assert sp.piece_to_id("\u2581") != sp.unk_id(), \
        "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting"
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("\u2581"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def load_data_shard(file: Path) -> Tensor:
    header_bytes = 256 * np.dtype("<i4").itemsize
    num_tokens = _read_num_tokens(file)
    with open(file, "rb") as f:
        f.seek(header_bytes)
        tokens = np.fromfile(f, dtype="<u2", count=num_tokens)
    return torch.from_numpy(tokens.astype(np.int32))


_SHARD_NTOKENS_CACHE: dict[str, int] = {}
_MMAP_CACHE: dict[str, np.memmap] = {}


def _read_num_tokens(file: Path) -> int:
    key = str(file)
    cached = _SHARD_NTOKENS_CACHE.get(key)
    if cached is not None:
        return cached
    header = np.fromfile(file, dtype="<i4", count=256)
    num_tokens = int(header[2])
    _SHARD_NTOKENS_CACHE[key] = num_tokens
    return num_tokens


def _get_shard_memmap(file: Path) -> np.memmap:
    key = str(file)
    mm = _MMAP_CACHE.get(key)
    if mm is not None:
        return mm
    n = _read_num_tokens(file)
    mm = np.memmap(file, mode="r", dtype="<u2", offset=256 * np.dtype("<i4").itemsize, shape=(n,))
    _MMAP_CACHE[key] = mm
    return mm


class DistributedTokenLoader:
    def __init__(self, files_pattern: str, rank: int, world_size: int, device: torch.device):
        self.files = [Path(p) for p in sorted(glob.glob(files_pattern))]
        if not self.files:
            raise FileNotFoundError(f"No files found for pattern: {files_pattern}")
        self.rank = rank
        self.world_size = world_size
        self.device = device
        # Shared RNG so every rank samples the identical global window sequence.
        self._rng = np.random.default_rng(1234)
        self._num_tokens = np.array([_read_num_tokens(f) for f in self.files], dtype=np.int64)
        n = len(self.files)
        self._cursor_phase = np.zeros(n, dtype=np.int64)
        self._cursor_block_count = np.zeros(n, dtype=np.int64)
        self._cursor_next = np.zeros(n, dtype=np.int64)
        self._cursor_start = np.zeros(n, dtype=np.int64)
        self._cursor_stride = np.ones(n, dtype=np.int64)
        self._cursor_init = np.zeros(n, dtype=bool)
        self._cfg = None
        self._eligible_shards = None
        self._base_block_counts = None
        self._batches_built = 0

    def _pick_coprime_stride(self, n: int) -> int:
        if n <= 1:
            return 1
        while True:
            s = int(self._rng.integers(1, n))
            if math.gcd(s, n) == 1:
                return s

    def _reset_cursor(self, si: int, seq_len: int) -> None:
        nt = int(self._num_tokens[si])
        max_phase = min(seq_len - 1, max(0, nt - seq_len - 1))
        phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0
        bc = (nt - 1 - phase) // seq_len
        self._cursor_phase[si] = phase
        self._cursor_block_count[si] = bc
        self._cursor_next[si] = 0
        self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0
        self._cursor_stride[si] = self._pick_coprime_stride(bc)
        self._cursor_init[si] = True

    def _ensure_cursor(self, si: int, seq_len: int) -> None:
        if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]:
            self._reset_cursor(si, seq_len)

    def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None:
        rem = count
        while rem > 0:
            self._ensure_cursor(si, seq_len)
            bc = int(self._cursor_block_count[si])
            ni = int(self._cursor_next[si])
            take = min(rem, bc - ni)
            phase = int(self._cursor_phase[si])
            start = int(self._cursor_start[si])
            stride = int(self._cursor_stride[si])
            for j in range(take):
                bi = (start + (ni + j) * stride) % bc
                out.append((si, phase + bi * seq_len))
            self._cursor_next[si] = ni + take
            rem -= take
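    # Why a coprime stride: stepping through block indices as
    # (start + k * stride) % block_count with gcd(stride, block_count) == 1
    # visits every block exactly once before wrapping, so each shard is read as
    # a full permutation rather than with repeats. For example (illustrative
    # values), block_count=6 and stride=5 starting at 0 yields 0, 5, 4, 3, 2, 1.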
    def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None:
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        num_seqs = local_tokens // seq_len
        global_num_seqs = num_seqs * self.world_size
        self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs)
        bbc = (self._num_tokens - 1) // seq_len
        eligible = bbc > 0
        self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64)
        self._base_block_counts = bbc[self._eligible_shards].astype(np.int64)

    def _sample_global_windows(self) -> list[tuple[int, int]]:
        assert self._cfg is not None and self._eligible_shards is not None
        _, seq_len, _, gns = self._cfg
        ec = int(self._eligible_shards.size)
        progress = min(self._batches_built / 1800.0, 1.0)
        remaining = np.empty(ec, dtype=np.float64)
        for i, si in enumerate(self._eligible_shards.tolist()):
            if self._cursor_init[si]:
                r = int(self._cursor_block_count[si]) - int(self._cursor_next[si])
                remaining[i] = float(max(r, 1))
            else:
                remaining[i] = float(self._base_block_counts[i])
        alpha = 0.90 - 0.40 * progress
        weights = np.power(remaining, alpha)
        ws = float(weights.sum())
        if not np.isfinite(ws) or ws <= 0.0:
            weights = np.ones(ec, dtype=np.float64)
            ws = float(weights.sum())
        probs = weights / ws
        low = min(max(8, self.world_size), ec, gns)
        high = min(max(32, self.world_size * 8), ec, gns)
        mix = max(1, min(int(round(low + progress * (high - low))), ec, gns))
        cp = self._rng.choice(ec, size=mix, replace=False, p=probs)
        cs = self._eligible_shards[cp]
        cpr = probs[cp].copy()
        cpr /= cpr.sum()
        counts = np.ones(mix, dtype=np.int64)
        extra = gns - mix
        if extra > 0:
            counts += self._rng.multinomial(extra, cpr).astype(np.int64)
        perm = self._rng.permutation(mix)
        cs, counts = cs[perm], counts[perm]
        buckets: list[list[tuple[int, int]]] = []
        for si, cnt in zip(cs.tolist(), counts.tolist()):
            b: list[tuple[int, int]] = []
            self._take_from_shard(int(si), seq_len, int(cnt), b)
            if b:
                if len(b) > 1:
                    bp = self._rng.permutation(len(b))
                    b = [b[int(k)] for k in bp.tolist()]
                buckets.append(b)
        windows: list[tuple[int, int]] = []
        active = [i for i, bk in enumerate(buckets) if bk]
        while active:
            order = self._rng.permutation(len(active))
            new_active: list[int] = []
            for oi in order.tolist():
                bi = active[oi]
                if buckets[bi]:
                    windows.append(buckets[bi].pop())
                if buckets[bi]:
                    new_active.append(bi)
            active = new_active
        return windows

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        if self._cfg is None:
            self._init_pipeline(global_tokens, seq_len, grad_accum_steps)
        _, _, num_seqs, _ = self._cfg
        gw = self._sample_global_windows()
        local_w = gw[self.rank::self.world_size]
        x = torch.empty((num_seqs, seq_len), dtype=torch.int64)
        y = torch.empty((num_seqs, seq_len), dtype=torch.int64)
        for slot, (si, pos) in enumerate(local_w):
            mm = _get_shard_memmap(self.files[si])
            window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64))
            x[slot] = window[:-1]
            y[slot] = window[1:]
        self._batches_built += 1
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
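
# Illustrative sketch of the shard-sampling schedule above (not used by
# training): early on (progress=0) shard weights are remaining**0.90, and by
# progress=1 they flatten to remaining**0.50, so sampling is biased toward
# large shards less aggressively as training advances. Toy numbers only.
def _sketch_shard_sampling_weights(remaining: np.ndarray, progress: float) -> np.ndarray:
    alpha = 0.90 - 0.40 * min(max(progress, 0.0), 1.0)
    weights = np.power(remaining.astype(np.float64), alpha)
    return weights / weights.sum()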
# ----------------------------------------
# Model Architecture
# ----------------------------------------

class RMSNorm(nn.Module):
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        w = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, w, bias)


class Rotary(nn.Module):
    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.train_seq_len = train_seq_len
        self.rope_dims = rope_dims if rope_dims > 0 else dim
        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            rd = self.rope_dims
            if seq_len > self.train_seq_len:
                scale = seq_len / self.train_seq_len
                new_base = self.base * (scale ** (rd / (rd - 2)))
                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
            else:
                inv_freq = self.inv_freq.to(device)
            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = freqs.cos()[None, :, None, :]
            self._sin_cached = freqs.sin()[None, :, None, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
    if rope_dims > 0 and rope_dims < x.size(-1):
        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
        half = rope_dims // 2
        x1, x2 = x_rope[..., :half], x_rope[..., half:]
        x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
        return torch.cat((x_rope, x_pass), dim=-1)
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
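
# Illustrative check for partial RoPE above (not called during training): only
# the first rope_dims channels of each head are rotated; the remaining channels
# pass through unchanged. The toy shapes below are assumptions for exposition.
def _sketch_partial_rope_passthrough() -> bool:
    rope_dims, head_dim = 16, 64
    rot = Rotary(head_dim, base=10000.0, train_seq_len=32, rope_dims=rope_dims)
    x = torch.randn(1, 8, 1, head_dim)  # (batch, seq, heads, head_dim)
    cos, sin = rot(8, x.device, x.dtype)
    y = apply_rotary_emb(x, cos, sin, rope_dims)
    # the unrotated tail of each head is bit-identical to the input
    return torch.equal(y[..., rope_dims:], x[..., rope_dims:])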
class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int,
                 rope_base: float, qk_gain_init: float, train_seq_len: int):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rope_dims = 0
        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len)
        self.use_xsa = False

    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
        B, T, H, D = y.shape
        Hkv = v.size(-2)
        group = H // Hkv
        y_g = y.reshape(B, T, Hkv, group, D)
        vn = F.normalize(v, dim=-1).unsqueeze(-2)
        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
        return (y_g - proj).reshape(B, T, H, D)

    def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        v = self.c_v(x)
        if v_embed is not None:
            v = v + v_embed
        v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
        y = flash_attn_3_func(q, k, v, causal=True)
        if self.use_xsa:
            y = self._xsa_efficient(y, v)
        y = y.reshape(bsz, seqlen, dim)
        return self.proj(y)


class ValueEmbedding(nn.Module):
    def __init__(self, vocab_size: int, ve_dim: int, model_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, ve_dim)
        nn.init.normal_(self.embed.weight, std=0.01)
        self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None
        if self.proj is not None:
            nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))

    def forward(self, token_ids: Tensor) -> Tensor:
        h = self.embed(token_ids)
        if self.proj is not None:
            h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)


class MLP(nn.Module):
    def __init__(self, dim: int, mlp_mult: float):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True

    def forward(self, x: Tensor) -> Tensor:
        return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square())


class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float,
                 rope_base: float, qk_gain_init: float, train_seq_len: int,
                 layer_idx: int = 0, ln_scale: bool = False):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0

    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed)
        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor)
        return x_out
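
# Illustrative property check for the XSA step above (not used by training):
# _xsa_efficient removes the component of each attention output along its
# normalized value direction, so the result is orthogonal to v within every
# KV group. Shapes below are toy assumptions for exposition.
def _sketch_xsa_orthogonality() -> float:
    attn = CausalSelfAttention(dim=64, num_heads=4, num_kv_heads=2,
                               rope_base=10000.0, qk_gain_init=4.0, train_seq_len=32)
    y = torch.randn(2, 8, 4, 16)  # (batch, seq, heads, head_dim)
    v = torch.randn(2, 8, 2, 16)  # (batch, seq, kv_heads, head_dim)
    out = attn._xsa_efficient(y, v)
    vn = F.normalize(v, dim=-1)
    grouped = out.reshape(2, 8, 2, 2, 16)  # regroup heads under their KV head
    residual = (grouped * vn.unsqueeze(-2)).sum(-1)
    return float(residual.abs().max())  # ~0 up to floating-point error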
class GPT(nn.Module):
    def __init__(self, h: Hyperparameters):
        super().__init__()
        self._ve_target_dim = h.num_kv_heads * (h.model_dim // h.num_heads)
        if h.logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}")
        self.tie_embeddings = h.tie_embeddings
        self.tied_embed_init_std = h.tied_embed_init_std
        self.logit_softcap = h.logit_softcap
        self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim)
        if h.embedding_dim != h.model_dim:
            self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False)
            self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False)
        else:
            self.embed_proj = None
            self.head_proj = None
        self.num_encoder_layers = h.num_layers // 2
        self.num_decoder_layers = h.num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32))
        self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None
        self.blocks = nn.ModuleList([
            Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base,
                  h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale)
            for i in range(h.num_layers)
        ])
        if h.rope_dims > 0:
            head_dim = h.model_dim // h.num_heads
            for block in self.blocks:
                block.attn.rope_dims = h.rope_dims
                block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims)
        self.ve_layer_indices = [int(x) for x in h.ve_layers.split(",") if x.strip()] if h.ve_enabled else []
        kv_dim = self._ve_target_dim
        if self.ve_layer_indices:
            self.ve_shared = ValueEmbedding(h.vocab_size, h.ve_dim, kv_dim)
            self.ve_layer_scales = nn.ParameterList(
                [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]
            )
        else:
            self.ve_shared = None
            self.ve_layer_scales = nn.ParameterList()
        self.value_embeds = nn.ModuleList()
        self.final_norm = RMSNorm()
        self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        if h.xsa_last_n > 0:
            for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers):
                self.blocks[i].attn.use_xsa = True

        # Modification 2: Depth Recurrence
        self.recur_layers = [int(x) for x in h.recur_layers.split(",") if x.strip()]
        self._recurrence_active = False

        self._init_weights()

    def set_recurrence_active(self, active: bool) -> None:
        self._recurrence_active = active
    def _get_virtual_layers(self) -> list[int]:
        """Return virtual->physical block mapping.
        When recurrence is active, the recur_layers are repeated once,
        e.g. with num_layers=11 and recur_layers=[4,5]:
            [0,1,2,3, 4,5, 4,5, 6,7,8,9,10]
        When inactive: [0,1,2,...,num_layers-1]
        """
        n = len(self.blocks)
        if not self._recurrence_active or not self.recur_layers:
            return list(range(n))
        virtual = []
        inserted = False
        for i in range(n):
            virtual.append(i)
            if not inserted and i == self.recur_layers[-1]:
                # repeat the recur_layers
                for rl in self.recur_layers:
                    virtual.append(rl)
                inserted = True
        return virtual

    def _init_weights(self) -> None:
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False):
                    nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)

    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None:
        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
            return None
        if ve_cache is not None and 've' not in ve_cache:
            ve_cache['ve'] = self.ve_shared(input_ids)
        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
        ve_idx = self.ve_layer_indices.index(layer_idx)
        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        if self.embed_proj is not None:
            x = self.embed_proj(x)
        x0 = x

        virtual_layers = self._get_virtual_layers()
        num_virtual = len(virtual_layers)
        num_enc = num_virtual // 2
        num_dec = num_virtual - num_enc

        skips: list[Tensor] = []
        ve_cache: dict = {}

        # Encoder phase
        for vi in range(num_enc):
            phys_idx = virtual_layers[vi]
            ve = self._get_ve(phys_idx, input_ids, ve_cache)
            x = self.blocks[phys_idx](x, x0, v_embed=ve)
            skips.append(x)

        # Decoder phase with U-Net skip connections
        for vi in range(num_dec):
            phys_idx = virtual_layers[num_enc + vi]
            if skips and vi < self.num_skip_weights:
                scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop()
                if self.skip_gates is not None:
                    g = torch.sigmoid(self.skip_gates[vi].to(dtype=x.dtype))[None, None, :]
                    x = torch.lerp(scaled_skip, x, g)
                else:
                    x = x + scaled_skip
            ve = self._get_ve(phys_idx, input_ids, ve_cache)
            x = self.blocks[phys_idx](x, x0, v_embed=ve)

        x = self.final_norm(x)
        if self.head_proj is not None:
            x = self.head_proj(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        logits = self.forward_logits(input_ids)
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean")
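
# Illustrative sketch of the depth-recurrence schedule (not used by training):
# mirrors GPT._get_virtual_layers for the default NUM_LAYERS=11 and
# RECUR_LAYERS="4,5" without instantiating a model, to show that recurrence
# adds two virtual layers while reusing the same physical parameter banks.
def _sketch_virtual_layer_order(num_layers: int = 11, recur_layers: tuple[int, ...] = (4, 5)) -> list[int]:
    virtual: list[int] = []
    for i in range(num_layers):
        virtual.append(i)
        if recur_layers and i == recur_layers[-1]:
            virtual.extend(recur_layers)  # repeat the hinge layers once
    return virtual  # [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]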
def classify_param(name: str) -> str:
    if "tok_emb" in name or "lm_head" in name:
        return "embed"
    if ".mlp." in name:
        return "mlp"
    if ".attn." in name or (".proj." in name and ".mlp." not in name):
        return "attn"
    return "other"

# ----------------------------------------
# Optimization
# ----------------------------------------

@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
                 nesterov: bool = True, weight_decay: float = 0.0):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
                 nesterov=nesterov, weight_decay=weight_decay),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0
        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    # Modification 1: MuonEq-R row normalization before NS5
                    update = g
                    row_norms = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-7)
                    update = update / row_norms.to(update.dtype)
                    g = zeropower_via_newtonschulz5(update, steps=backend_steps)
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                curr += p.numel()
            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
            wd = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                if wd > 0.0:
                    p.data.mul_(1.0 - lr * wd)
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss
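
# Illustrative sketch of the MuonEq-R preprocessing above (not part of the
# optimizer step): every gradient row is scaled to unit L2 norm before the
# Newton-Schulz orthogonalization, so no single row dominates the iteration.
# This mirrors the two lines marked "Modification 1" in Muon.step.
def _sketch_muoneq_r(update: Tensor) -> Tensor:
    row_norms = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-7)
    normalized = update / row_norms.to(update.dtype)
    # all rows now have (approximately) unit norm; NS5 then orthogonalizes
    return zeropower_via_newtonschulz5(normalized, steps=5)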
class Optimizers():
    def __init__(self, h: Hyperparameters, base_model: GPT):
        block_named_params = list(base_model.blocks.named_parameters())
        matrix_params = [
            p
            for name, p in block_named_params
            if p.ndim == 2 and not any(pattern in name for pattern in
                                       CONTROL_TENSOR_NAME_PATTERNS)
        ]
        scalar_params = [
            p
            for name, p in block_named_params
            if p.ndim < 2 or any(pattern in name for pattern in
                                 CONTROL_TENSOR_NAME_PATTERNS)
        ]
        if base_model.skip_weights.numel() > 0:
            scalar_params.append(base_model.skip_weights)
        if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0:
            scalar_params.append(base_model.skip_gates)

        token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr
        tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
        if base_model.ve_shared is not None:
            tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr})
            if base_model.ve_shared.proj is not None:
                matrix_params.append(base_model.ve_shared.proj.weight)
            scalar_params.append(base_model.ve_shared.scale)
            for s in base_model.ve_layer_scales:
                scalar_params.append(s)

        self.optimizer_tok = torch.optim.AdamW(
            tok_params,
            betas=(h.beta1, h.beta2),
            eps=h.adam_eps,
            weight_decay=h.embed_wd,
            fused=True,
        )
        self.optimizer_muon = Muon(
            matrix_params,
            lr=h.matrix_lr,
            momentum=h.muon_momentum,
            backend_steps=h.muon_backend_steps,
            weight_decay=h.muon_wd,
        )
        for group in self.optimizer_muon.param_groups:
            group["base_lr"] = h.matrix_lr
        self.optimizer_scalar = torch.optim.AdamW(
            [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}],
            betas=(h.beta1, h.beta2),
            eps=h.adam_eps,
            weight_decay=h.adam_wd,
            fused=True,
        )
        self.optimizers: list[torch.optim.Optimizer] = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar]
        if base_model.lm_head is not None:
            self.optimizer_head = torch.optim.Adam(
                [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}],
                betas=(h.beta1, h.beta2),
                eps=h.adam_eps,
                fused=True,
            )
            self.optimizers.insert(1, self.optimizer_head)
        else:
            self.optimizer_head = None

    def __iter__(self):
        return iter(self.optimizers)

    def zero_grad_all(self) -> None:
        for opt in self.optimizers:
            opt.zero_grad(set_to_none=True)

    def step(self):
        for opt in self.optimizers:
            opt.step()
        self.zero_grad_all()

# ----------------------------------------
# Quantization
# ----------------------------------------

CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale",
    ).split(",")
    if pattern
)
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0


def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    t32 = t.float()
    if t32.ndim == 2:
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale


def restore_fp32_params(model: nn.Module) -> None:
    """After .bfloat16(), restore CastedLinear weights and control params to FP32."""
    for module in model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    for name, param in model.named_parameters():
        if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
            param.data = param.data.float()
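
# Illustrative int6 round-trip (not used by training): symmetric 6-bit codes
# live in [-31, 31], and each row stores one fp16 scale, so a row with max
# magnitude m reconstructs with worst-case rounding error of about m / 62
# per weight. Toy single-row version of the per-row quantizer below.
def _sketch_int6_roundtrip(w_row: Tensor) -> Tensor:
    scale = (w_row.abs().max() / 31.0).clamp_min(1.0 / 31.0)
    q = torch.clamp(torch.round(w_row / scale), -31, 31).to(torch.int8)
    return q.float() * scale  # dequantized row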
def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    t32 = t.float()
    if t32.ndim == 2:
        best_q, best_s, best_err = None, None, float('inf')
        for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
            if pct < 1.0:
                row_clip = torch.quantile(t32.abs(), pct, dim=1)
            else:
                row_clip = t32.abs().amax(dim=1)
            s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
            q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8)
            recon = q.float() * s.float()[:, None]
            err = (t32 - recon).pow(2).mean().item()
            if err < best_err:
                best_q, best_s, best_err = q, s, err
        return best_q, best_s
    amax = t32.abs().max().item()
    scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8)
    return q, scale


def collect_hessians(
    model: nn.Module,
    train_loader: DistributedTokenLoader,
    h: Hyperparameters,
    device: torch.device,
    n_calibration_batches: int = 64,
) -> dict[str, Tensor]:
    """Run calibration batches and collect H = X^T X for each CastedLinear layer."""
    hessians: dict[str, Tensor] = {}
    hooks = []

    def make_hook(name: str):
        def hook_fn(module, inp, out):
            x = inp[0].detach().float()
            if x.ndim == 3:
                x = x.reshape(-1, x.shape[-1])
            if name not in hessians:
                hessians[name] = torch.zeros(
                    x.shape[1], x.shape[1], dtype=torch.float32, device=device
                )
            hessians[name].addmm_(x.T, x)
        return hook_fn

    for name, module in model.named_modules():
        if isinstance(module, CastedLinear) and module.weight.numel() > 65536:
            cat = classify_param(name + ".weight")
            if cat in ("mlp", "attn"):
                hooks.append(module.register_forward_hook(make_hook(name + ".weight")))

    model.eval()
    with torch.no_grad():
        for _i in range(n_calibration_batches):
            x, y = train_loader.next_batch(
                h.train_batch_tokens,
                h.train_seq_len, h.grad_accum_steps,
            )
            model.forward_logits(x)

    for hk in hooks:
        hk.remove()

    for name in hessians:
        hessians[name] = hessians[name].cpu() / n_calibration_batches

    return hessians
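
# Minimal sketch of the Hessian statistic above (not used by training): for a
# linear layer y = x @ W.T, GPTQ only needs H = sum over calibration batches of
# X.T @ X, which a forward hook can accumulate without ever storing X.
def _sketch_linear_hessian(layer: nn.Linear, calib_inputs: list[Tensor]) -> Tensor:
    H = torch.zeros(layer.in_features, layer.in_features)

    def hook_fn(mod, inp, out):
        x = inp[0].reshape(-1, inp[0].shape[-1]).float()
        H.addmm_(x.T, x)  # accumulate X^T X in place

    hook = layer.register_forward_hook(hook_fn)
    with torch.no_grad():
        for x in calib_inputs:
            layer(x)
    hook.remove()
    return H / max(len(calib_inputs), 1)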
def gptq_quantize_weight(
    w: Tensor,
    H: Tensor,
    clip_range: int = 31,
    block_size: int = 128,
) -> tuple[Tensor, Tensor]:
    """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023)."""
    W_orig = w.float().clone()
    rows, cols = W_orig.shape
    H = H.float().clone()

    # Zero out dead columns and add damping
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    damp = 0.01 * H.diag().mean()
    H.diagonal().add_(damp)

    # Column reordering by descending Hessian diagonal (actorder)
    perm = torch.argsort(H.diag(), descending=True)
    invperm = torch.argsort(perm)
    W_perm = W_orig[:, perm].clone()
    W_perm[:, dead[perm]] = 0
    H = H[perm][:, perm]

    # Upper Cholesky of the inverse
    try:
        Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
        Hinv = torch.linalg.cholesky(Hinv, upper=True)
    except torch.linalg.LinAlgError:
        return quantize_int6_per_row(W_orig, clip_range)

    # Search over scale candidates, running full GPTQ for each
    best_q, best_scale, best_err = None, None, float('inf')
    for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
        if pct < 1.0:
            row_clip = torch.quantile(W_orig.abs(), pct, dim=1)
        else:
            row_clip = W_orig.abs().amax(dim=1)
        s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
        sf = s.float()

        Q = torch.zeros(rows, cols, dtype=torch.int8)
        W_work = W_perm.clone()

        for i1 in range(0, cols, block_size):
            i2 = min(i1 + block_size, cols)
            W_block = W_work[:, i1:i2].clone()
            Hinv_block = Hinv[i1:i2, i1:i2]
            Err = torch.zeros(rows, i2 - i1)
            for j in range(i2 - i1):
                w_col = W_block[:, j]
                d = Hinv_block[j, j]
                q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range)
                Q[:, i1 + j] = q_col.to(torch.int8)
                err = (w_col - q_col.float() * sf) / d
                Err[:, j] = err
                W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0)
            if i2 < cols:
                W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:]

        recon = Q.float() * sf[:, None]
        mse = (W_perm - recon).pow(2).mean().item()
        if mse < best_err:
            best_q, best_scale, best_err = Q, s, mse

    return best_q[:, invperm], best_scale


def gptq_mixed_quantize_int6(
    state_dict: dict[str, Tensor],
    int6_cats: set[str],
    hessians: dict[str, Tensor],
) -> tuple[dict[str, Tensor], dict[str, object]]:
    """Mixed quantization using full GPTQ for layers with Hessians, fallback to clip-search."""
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    gptq_count = 0
    fallback_count = 0

    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = classify_param(name)

        if not t.is_floating_point() or t.numel() <= 65536:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue

        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue

        if cat in int6_cats and t.ndim == 2:
            if name in hessians:
                q, s = gptq_quantize_weight(t, hessians[name])
                gptq_count += 1
                meta[name] = {"type": "int6", "method": "gptq"}
            else:
                q, s = quantize_int6_per_row(t)
                fallback_count += 1
                meta[name] = {"type": "int6", "method": "clip_search"}
            result[name + ".q"] = q
            result[name + ".scale"] = s
        elif cat in int6_cats and t.ndim >= 1:
            q, s = quantize_int6_per_row(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int6"}
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}

    log(f"GPTQ quantization: {gptq_count} layers with full GPTQ, {fallback_count} fallback to clip-search")
    return result, meta


def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = classify_param(name)
        if not t.is_floating_point() or t.numel() <= 65536:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if cat in int6_cats and t.ndim >= 1:
            q, s = quantize_int6_per_row(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int6"}
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta
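
# Worked one-column sketch of the GPTQ inner loop above (illustrative only,
# not part of the pipeline): quantize a column, then push its scaled rounding
# error into the not-yet-quantized columns through the inverse-Hessian row.
# This is what lets later columns compensate for earlier rounding decisions.
def _sketch_gptq_column_step(W: Tensor, Hinv: Tensor, sf: Tensor, j: int, clip_range: int = 31) -> Tensor:
    w_col = W[:, j]
    d = Hinv[j, j]
    q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range)
    err = (w_col - q_col * sf) / d
    W = W.clone()
    W[:, j + 1:] -= err.unsqueeze(1) * Hinv[j, j + 1:].unsqueeze(0)
    return W  # remaining columns now absorb column j's rounding error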
def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
                          template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta.get(name)
        if info is None:
            continue
        orig_dtype = orig.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            t = result[name]
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
        else:
            out[name] = (q.float() * float(s.item())).to(orig_dtype)
    return out


_BSHF_MAGIC = b"BSHF"


def _byte_shuffle(data: bytes, stride: int = 2) -> bytes:
    """Transpose byte stream by stride position for better compression."""
    if stride <= 1 or len(data) < stride:
        return data
    src = np.frombuffer(data, dtype=np.uint8)
    n = len(src)
    out = np.empty(n, dtype=np.uint8)
    dest_off = 0
    for pos in range(stride):
        chunk = src[pos::stride]
        out[dest_off:dest_off + len(chunk)] = chunk
        dest_off += len(chunk)
    return _BSHF_MAGIC + bytes([stride]) + out.tobytes()


def _byte_unshuffle(data: bytes) -> bytes:
    """Inverse of _byte_shuffle. Auto-detects BSHF magic header."""
    if len(data) < 5 or data[:4] != _BSHF_MAGIC:
        return data
    stride = data[4]
    if stride < 2:
        return data[5:]
    payload = np.frombuffer(data, dtype=np.uint8, offset=5)
    n = len(payload)
    out = np.empty(n, dtype=np.uint8)
    src_off = 0
    for pos in range(stride):
        chunk_len = n // stride + (1 if pos < n % stride else 0)
        out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len]
        src_off += chunk_len
    return out.tobytes()


def _compress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes:
    if byte_shuffle:
        data = _byte_shuffle(data)
    if compressor == "lzma":
        return lzma.compress(data, preset=6)
    elif compressor == "brotli":
        import brotli as _brotli
        return _brotli.compress(data, quality=11)
    raise ValueError(f"Unknown compressor: {compressor!r}")


def _decompress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes:
    if compressor == "lzma":
        raw = lzma.decompress(data)
    elif compressor == "brotli":
        import brotli as _brotli
        raw = _brotli.decompress(data)
    else:
        raise ValueError(f"Unknown compressor: {compressor!r}")
    if byte_shuffle:
        raw = _byte_unshuffle(raw)
    return raw
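
# Illustrative round-trip for the byte shuffle above (not used by training):
# with stride=2, int16/fp16 byte streams are split into a "low byte" plane and
# a "high byte" plane; the high-byte plane is far more repetitive, which is
# what brotli exploits. _byte_unshuffle inverts _byte_shuffle exactly.
def _sketch_byte_shuffle_roundtrip() -> bool:
    data = bytes(range(16)) * 4  # toy payload
    return _byte_unshuffle(_byte_shuffle(data, stride=2)) == data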
def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> None:
    model_bytes = None
    code_bytes = len(code.encode("utf-8"))
    if h.is_main_process:
        torch.save(base_model.state_dict(), h.model_path)
        model_bytes = os.path.getsize(h.model_path)
        log(f"Serialized model: {model_bytes} bytes")
        log(f"Code size: {code_bytes} bytes")

    sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()}
    if h.gptq_enabled:
        log("GPTQ:collecting Hessians from calibration data...")
        t0 = time.perf_counter()
        calib_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size,
                                              torch.device("cuda", h.local_rank))
        hessians = collect_hessians(
            base_model, calib_loader, h,
            torch.device("cuda", h.local_rank),
            n_calibration_batches=h.gptq_calibration_batches,
        )
        log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s")
        quant_result, quant_meta = gptq_mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, hessians)
    else:
        quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"})

    # Fast selective +-1 pruning to fit under target size
    target_bytes = 16_000_000
    quant_buf_check = io.BytesIO()
    torch.save({"w": quant_result, "m": quant_meta}, quant_buf_check)
    check_blob = _compress(quant_buf_check.getvalue(), h.compressor)
    unpruned_sz = len(check_blob) + code_bytes
    log(f"selective_prune: unpruned={unpruned_sz/1e6:.2f}MB target={target_bytes/1e6:.1f}MB")
    if unpruned_sz > target_bytes:
        excess = unpruned_sz - target_bytes
        safety_margin = int(excess * 8)  # prune 8x the excess for safety
        ones_info = []
        for name, info in quant_meta.items():
            if not (isinstance(info, dict) and info.get("type") == "int6"):
                continue
            qk, sk = name + ".q", name + ".scale"
            if qk not in quant_result or sk not in quant_result:
                continue
            q, s = quant_result[qk], quant_result[sk]
            if s.ndim > 0:
                ones_mask = (q.abs() == 1)
                if ones_mask.any():
                    row_idx = torch.arange(q.shape[0]).unsqueeze(1).expand_as(q)[ones_mask]
                    flat_idx = torch.arange(q.numel()).reshape(q.shape)[ones_mask]
                    errors = s.float()[row_idx].pow(2)
                    for fi, err in zip(flat_idx.tolist(), errors.tolist()):
                        ones_info.append((qk, fi, err))
        ones_info.sort(key=lambda x: x[2])
        n_prune = min(safety_margin, len(ones_info))
        log(f"selective_prune: pruning {n_prune}/{len(ones_info)} lowest-error ±1 values (excess={excess}B)")
        for i in range(n_prune):
            quant_result[ones_info[i][0]].view(-1)[ones_info[i][1]] = 0
    else:
        log("selective_prune: already fits, no pruning needed")

    quant_buf = io.BytesIO()
    torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = _compress(quant_raw, h.compressor)
    quant_file_bytes = len(quant_blob)
    bytes_total = quant_file_bytes + code_bytes
    if h.is_main_process:
        with open(h.quantized_model_path, "wb") as f:
            f.write(quant_blob)
        log(f"Serialized model int6+{h.compressor}: {quant_file_bytes} bytes")
        log(f"Total submission size int6+{h.compressor}: {bytes_total} bytes")
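
# Note on the selective pruning above: zeroing an int6 code of ±1 changes the
# reconstructed weight by exactly its row scale s, so sorting candidates by
# s**2 prunes in ascending order of squared reconstruction error. The 8x
# safety factor over the measured byte excess is a heuristic, since the bytes
# saved per pruned value depend on how well the result compresses.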
def deserialize(h: Hyperparameters, device: torch.device) -> GPT:
    eval_model = GPT(h).to(device).bfloat16()
    restore_fp32_params(eval_model)

    sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()}

    with open(h.quantized_model_path, "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(
        io.BytesIO(_decompress(quant_blob_disk, h.compressor)),
        map_location="cpu",
    )
    deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)
    eval_model.load_state_dict(deq_state, strict=True)

    return eval_model

# ----------------------------------------
# Evaluation
# ----------------------------------------

def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]:
    val_loss = (loss_sum / token_count).item()
    val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item())
    return val_loss, val_bpb


def eval_val(
    h: Hyperparameters,
    device: torch.device,
    val_data: ValidationData,
    model: nn.Module
) -> tuple[float, float]:
    seq_len = h.eval_seq_len
    local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps)
    if local_batch_tokens < seq_len:
        raise ValueError(
            "VAL_BATCH_TOKENS must provide at least one sequence per rank; "
            f"got VAL_BATCH_TOKENS={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, "
            f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}"
        )
    local_batch_seqs = local_batch_tokens // seq_len
    total_seqs = (val_data.val_tokens.numel() - 1) // seq_len
    seq_start = (total_seqs * h.rank) // h.world_size
    seq_end = (total_seqs * (h.rank + 1)) // h.world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * seq_len
            raw_end = batch_seq_end * seq_len + 1
            local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, seq_len)
            y = local[1:].reshape(-1, seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    model.train()
    return _loss_bpb(val_loss_sum, val_token_count, val_byte_count)
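
# Illustrative sketch of the sliding-window bookkeeping below (not used by
# training): windows advance by eval_stride tokens, and each window scores only
# its final eval_stride tokens (except the first window, which scores
# everything), so every token is predicted with at least
# seq_len - eval_stride tokens of preceding context.
def _sketch_scored_spans(total_tokens: int, seq_len: int, stride: int) -> list[tuple[int, int]]:
    context_size = seq_len - stride
    spans = []
    for ws in range(0, total_tokens, stride):
        if ws + context_size >= total_tokens:
            continue
        we = min(ws + seq_len, total_tokens)
        s = 0 if ws == 0 else context_size
        spans.append((ws + s, we))  # half-open scored span
    return spans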
def eval_val_sliding(
    h: Hyperparameters,
    device: torch.device,
    val_data: ValidationData,
    base_model: nn.Module,
    batch_seqs: int = 32
) -> tuple[float, float]:
    """Sliding window evaluation: each token scored with maximum context."""
    base_model.eval()
    logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)

    seq_len = h.eval_seq_len
    context_size = seq_len - h.eval_stride
    total_tokens = val_data.val_tokens.numel() - 1

    window_starts = [ws for ws in range(0, total_tokens, h.eval_stride)
                     if ws + context_size < total_tokens]

    total_windows = len(window_starts)
    my_s = (total_windows * h.rank) // h.world_size
    my_e = (total_windows * (h.rank + 1)) // h.world_size
    my_windows = window_starts[my_s:my_e]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)

            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []

            for i, ws in enumerate(batch_ws):
                we = min(ws + seq_len, total_tokens)
                wlen = we - ws
                wlens.append(wlen)
                chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = logits_fn(x_batch)

            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)

            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                s = 0 if ws == 0 else context_size
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                tb = val_data.base_bytes_lut[tgt].to(torch.float64)
                tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    base_model.train()
    return _loss_bpb(loss_sum, token_count, byte_count)
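
# Note on the TTT evaluator below (disabled by default, TTT_ENABLED=0): it is
# "score-first". For each validation chunk, every window is scored with the
# current weights BEFORE any gradient step touches that chunk, and the final
# chunk is never trained on. Training therefore only adapts the model to text
# that has already been scored, which is what keeps it legal under the
# no-eval-time-adaptation rules cited in the README.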

# ----------------------------------------
# TTT (Test-Time Training) - Legal Score-First
# ----------------------------------------

def eval_val_ttt(
    h: Hyperparameters,
    base_model: nn.Module,
    device: torch.device,
    val_data: ValidationData,
    log_fn=None,
) -> tuple[float, float]:
    """Legal score-first TTT: score each chunk with sliding windows,
    then train on it. Every token scored BEFORE any update that could use it."""
    seq_len = h.eval_seq_len
    stride = h.eval_stride
    total_tokens = val_data.val_tokens.numel() - 1
    ttt_chunk = h.ttt_chunk_tokens
    rank = h.rank
    world_size = h.world_size
    if log_fn is None:
        log_fn = lambda msg: None

    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]

    num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk
    chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)]
    for ws in window_starts:
        end = min(ws + seq_len, total_tokens)
        wlen = end - ws
        s = 0 if ws == 0 else max(wlen - stride, 0)
        scored_start = ws + s
        ci = min(scored_start // ttt_chunk, num_chunks - 1)
        chunk_windows[ci].append(ws)

    log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} "
           f"total_windows={len(window_starts)} stride={stride} "
           f"ttt_lr={h.ttt_lr} ttt_epochs={h.ttt_epochs} "
           f"freeze_blocks={h.ttt_freeze_blocks}")

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    frozen_block_ids = set(range(min(h.ttt_freeze_blocks, len(base_model.blocks))))
    ttt_params = []
    for name, p in base_model.named_parameters():
        freeze = False
        for bi in frozen_block_ids:
            if f"blocks.{bi}." in name:
                freeze = True
                break
        if freeze:
            p.requires_grad_(False)
        else:
            p.requires_grad_(True)
            ttt_params.append(p)

    log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} "
           f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}")

    optimizer = torch.optim.SGD(ttt_params, lr=h.ttt_lr, momentum=h.ttt_momentum)
    batch_seqs = h.ttt_batch_seqs
    t0 = time.perf_counter()

    for ci in range(num_chunks):
        windows = chunk_windows[ci]
        if not windows:
            continue
        chunk_start = ci * ttt_chunk
        chunk_end = min((ci + 1) * ttt_chunk, total_tokens)

        # --- Phase 1: SCORE this chunk's windows (no_grad for TTT compat) ---
        my_s = (len(windows) * rank) // world_size
        my_e = (len(windows) * (rank + 1)) // world_size
        my_windows = windows[my_s:my_e]

        base_model.eval()
        with torch.no_grad():
            for bi in range(0, len(my_windows), batch_seqs):
                batch_ws = my_windows[bi:bi + batch_seqs]
                bsz = len(batch_ws)
                x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
                y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
                wlens: list[int] = []
                for i, ws in enumerate(batch_ws):
                    end = min(ws + seq_len, total_tokens)
                    wlen = end - ws
                    wlens.append(wlen)
                    chunk_tok = val_data.val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                    x_batch[i, :wlen] = chunk_tok[:-1]
                    y_batch[i, :wlen] = chunk_tok[1:]
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    logits = base_model.forward_logits(x_batch)
                nll = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)).float(),
                    y_batch.reshape(-1), reduction="none",
                ).reshape(bsz, seq_len)
                for i, ws in enumerate(batch_ws):
                    wlen = wlens[i]
                    s = 0 if ws == 0 else max(wlen - stride, 0)
                    scored_nll = nll[i, s:wlen].to(torch.float64)
                    loss_sum += scored_nll.sum()
                    token_count += float(wlen - s)
                    tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen]
                    tb = val_data.base_bytes_lut[tgt].to(torch.float64)
                    tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64)
                    byte_count += tb.sum()

        # --- Phase 2: TRAIN on this chunk (already scored = legal) ---
        is_last_chunk = (ci == num_chunks - 1)
        if not is_last_chunk and h.ttt_epochs > 0:
            base_model.train()
            chunk_seqs = (chunk_end - chunk_start) // seq_len
            if chunk_seqs > 0:
                cos_lr = h.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1)))
                for pg in optimizer.param_groups:
                    pg['lr'] = cos_lr
                my_seq_s = (chunk_seqs * rank) // world_size
                my_seq_e = (chunk_seqs * (rank + 1)) // world_size
                my_chunk_seqs = my_seq_e - my_seq_s
                for _ep in range(h.ttt_epochs):
                    for bs in range(0, my_chunk_seqs, batch_seqs):
                        be = min(bs + batch_seqs, my_chunk_seqs)
                        actual_bs = my_seq_s + bs
                        start_tok = chunk_start + actual_bs * seq_len
                        end_tok = chunk_start + (my_seq_s + be) * seq_len + 1
                        if end_tok > val_data.val_tokens.numel():
                            continue
                        local = val_data.val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64)
                        x = local[:-1].reshape(-1, seq_len)
                        y = local[1:].reshape(-1, seq_len)
                        optimizer.zero_grad(set_to_none=True)
                        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                            loss = base_model(x, y)
                        loss.backward()
                        if world_size > 1:
                            for p in ttt_params:
                                if p.grad is not None:
                                    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
                        torch.nn.utils.clip_grad_norm_(ttt_params, h.ttt_grad_clip)
                        optimizer.step()
        if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1):
            elapsed = time.perf_counter() - t0
            rl = loss_sum.item() / max(token_count.item(), 1)
            rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0
            log_fn(f"  ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s")

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item())

    for p in base_model.parameters():
        p.requires_grad_(True)
    base_model.eval()

    log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} "
           f"elapsed={time.perf_counter() - t0:.1f}s")
    return val_loss, val_bpb


# ----------------------------------------
# Eval orchestration
# ----------------------------------------

def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]:
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    val_loss, val_bpb = fn(*args, **kwargs)
    torch.cuda.synchronize()
    elapsed_ms = 1000.0 * (time.perf_counter() - t0)
    log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms")
    return val_loss, val_bpb


def run_evals(
    h: Hyperparameters,
    device: torch.device,
    val_data: ValidationData,
    eval_model: torch.nn.Module
):
    compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True)
    timed_eval("final_int6_roundtrip", eval_val, h, device, val_data, compiled_model)
    if h.sliding_window_enabled:
        timed_eval("final_int6_sliding_window", eval_val_sliding, h, device, val_data, eval_model)
    if h.ttt_enabled:
        timed_eval("final_int6_ttt", eval_val_ttt, h, eval_model, device, val_data, log_fn=log)
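
# Illustrative sketch of the LR schedule implemented by lr_mul() inside
# train_model below (not used at runtime): the multiplier is 1.0 for the first
# 1 - warmdown_frac of training, then decays linearly to min_lr. With the
# warmdown covering 66.7% of steps (per the README), decay starts a third of
# the way in. The default arguments here are assumptions for exposition.
def _sketch_lr_mul(frac: float, warmdown_frac: float = 2.0 / 3.0, min_lr: float = 0.0) -> float:
    if warmdown_frac <= 0 or frac < 1.0 - warmdown_frac:
        return 1.0
    return max((1.0 - frac) / warmdown_frac, min_lr)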

# -----------------------------
# Training
# -----------------------------

def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData) -> tuple[GPT, nn.Module]:
    # Set up model
    base_model = GPT(h).to(device).bfloat16()
    restore_fp32_params(base_model)
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    if h.distributed:
        model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False)
    else:
        model = compiled_model
    log(f"model_params:{sum(p.numel() for p in base_model.parameters())}")

    # Set up optimizer and load train data
    optimizers = Optimizers(h, base_model)
    train_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device)

    # Helper functions for training
    max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None
    if h.gptq_enabled and max_wallclock_ms is not None:
        max_wallclock_ms -= h.gptq_reserve_seconds * 1000.0
        log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms")

    def training_frac(step: int, elapsed_ms: float) -> float:
        """Fraction of training completed (0 to 1), using step or wallclock."""
        if max_wallclock_ms is None:
            return step / max(h.iterations, 1)
        return elapsed_ms / max(max_wallclock_ms, 1e-9)

    def lr_mul(frac: float) -> float:
        if h.warmdown_frac <= 0:
            return 1.0
        if frac >= 1.0 - h.warmdown_frac:
            return max((1.0 - frac) / h.warmdown_frac, h.min_lr)
        return 1.0

    def step_fn(step, lr_scale):
        optimizers.zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(h.grad_accum_steps):
            if h.distributed:
                model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1
            x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss / h.grad_accum_steps).backward()
        train_loss /= h.grad_accum_steps

        frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum
        for group in optimizers.optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * lr_scale

        if h.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm)

        optimizers.step()
        return train_loss

    # Model warmup
    if h.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(h.warmup_steps):
            step_fn(warmup_step, 1.0)
            if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps:
                log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        optimizers.zero_grad_all()
        if h.distributed:
            model.require_backward_grad_sync = True
        train_loader = DistributedTokenLoader(
            h.train_files, h.rank, h.world_size, device)

    # Training loop
    ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()}
    ema_decay = h.ema_decay

    training_time_ms = 0.0
    stop_after_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step)

        # Modification 2: activate recurrence at recur_start_step
        if step == h.recur_start_step and not base_model._recurrence_active:
            base_model.set_recurrence_active(True)
            log(f"recurrence:activated at step {step}, virtual_layers={base_model._get_virtual_layers()}")

        should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0)
        if should_validate:
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(h, device, val_data, model)
            log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}")
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < h.iterations:
                log(
                    f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms "
                    f"step: {step}/{h.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        frac = training_frac(step, elapsed_ms)
        scale = lr_mul(frac)
        train_loss = step_fn(step, scale)

        with torch.no_grad():
            for name, t in base_model.state_dict().items():
                ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
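        # EMA with decay 0.997 averages over an effective horizon of roughly
        # 1 / (1 - 0.997) ≈ 333 recent steps' worth of weights.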
h.train_batch_tokens / (approx_training_time_ms / 1000.0) - log( - f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} " - f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if h.distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Weight averaging - log("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - - return base_model, compiled_model - - -def train_and_eval(h: Hyperparameters, device: torch.device) -> None: - random.seed(h.seed) - np.random.seed(h.seed) - torch.manual_seed(h.seed) - torch.cuda.manual_seed_all(h.seed) - - val_data = ValidationData(h, device) - log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") - log(f"val_tokens: {val_data.val_tokens.numel() - 1}") - - base_model, compiled_model = train_model(h, device, val_data) - timed_eval("pre-quantization post-ema", eval_val, h, device, val_data, compiled_model) - - serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) - if h.distributed: - dist.barrier() - - eval_model = deserialize(h, device) - # Activate recurrence on eval model for consistent evaluation - eval_model.set_recurrence_active(base_model._recurrence_active) - - run_evals(h, device, val_data, eval_model) - - -def main(): - # Modification 2: increase dynamo cache size for recurrence - torch._dynamo.config.cache_size_limit = 32 - - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.set_float32_matmul_precision("high") - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - torch._dynamo.config.optimize_ddp = False - - h = Hyperparameters() - set_logging_hparams(h) - if h.is_main_process: - os.makedirs("logs", exist_ok=True) - log(100 * "=", console=False) - log("Hyperparameters:", console=True) - for k, v in sorted(vars(type(h)).items()): - if not k.startswith("_"): - log(f" {k}: {v}", console=True) - log(Path(__file__).read_text(encoding="utf-8"), console=False) - log("=" * 100, console=False) - log(f"Running Python {sys.version}", console=False) - log(f"Running 
PyTorch {torch.__version__}", console=False) - log( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log("=" * 100, console=False) - - train_and_eval(h, device) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() +import lzma as L,base64 as B +__wrapper_size__=23447 +exec(L.decompress(B.b85decode("{Wp48S^xk9=GL@E0stWa8~^|S5YJf5;XFo2+g$)On@VT6Qap3bu0*kgCR~YUqB0W9R)iarr*QtEZpesGY3>~CZRiK|6Dwut$nH#N""!RYqQnA}G^`ZsFO;ar92)Xt#3E3Ki5S1}OfSx<=$c<4=h|J{kt$27^CQ01M+lVgZ0tGgX0&I*V@{U&JgYc0U!(4F-btCy*+qzv6D""p~UW!y~6{U*}y$E@2-R}vd?t*s#fnDO{!j>OImt34A(d+9n>hnnvzmd((_D1Cghg~(bQ$Yj)>!Y%{*o9ex8FWa#U)!OI!!5Prl^?bnBX2V(=(Bvc+CvGo!S{LhLn7pSsR!}@""U=OBW0)h6IYneQ1{|$<&k9TS^qGQpb-;#vEPAl%11UF)?6mtC8c04XzR$+h2=j84E2|i`pOEt$uyM`lGs*ejIF-}^SvSRZK$ePh1""`gt+?%1r#=OVy3pW`{ofc)6PhQQRP|_h56zl+sQ(le1eJ^&&qZxdGb15""aOb^-R1ouqi-H1_w|H;g(()bKz_!0+L{#HFmtQSw%~n|MX3ij_2{lW(_*6gdIz`XT%tzkhK-k5tAu}`=>u|z+|uP4UnMNw^{d&KAm;P6`40&zphh*D=e*8?KGZuo~y*`y#Wg}r(PV}?J$Of6s2V+h!p=48x=kjDY_aX`;C#9(p^)jAnKY3g{5DbCkKssjYrRZXs2-`q~ph{?^weRV5FIKs2lBmbcg&RzLuOlaj@{R`Bfpp)eb""D}Ew7C`>;2>B3jf`(s3Yq*>}BxRhe{CvdtS;G|JdNs#+jk4ep3dje2Aqp%-8na-AvoSenabJ|$$wl8x*@Ue0+ZLIPX!Z`X`fDGAy""*mN{vNnb*|);N0fh-Z`3Aau{>lWape5IR~oJnXs<;zeX^21})nE$qd8r&8ti65NPa0a~or2;Suxwo%f(tuSS@rF*DUP|;0zJT4F7""RlJu9oEJotpZlOYK^$6&@1D6Ybg$yb(#g*e~5V+DpFWW@xtT}P3;OWc#i`!T6_OgC#jrG@Z-u}yx8""D+(63t^Z!rAvVq^;Hlx=NeuDV(m5K`Lloy+!tf4a-gljma9pI-QkvA4QxjGark_SRsKw%OOGOK?F6!GAG}Y3eXW8Andz%""B`_|L#`TG!i_Sn38W@mXk8(m}NGMRyY!nG7#(7q2kx=|C9XxNBD8;s;!clxeRlYzvgugFsQ?+B#pA|G9oM#}UvB}vG?(T|bb!%<)eY{jK?T}o!!L;8MX""aR9n&d~6*Y^qQ+_l?~MMX+}0>=iu&u=R#57w@o9&)|V4Ug=h=-{88=2!gonRZ6P5zIsz|R)AX)sGS^%ca4&c`)xE88mbC%}O7HV%""CNLM;Td~lRF~xmpFQB%OIB~VgMRMqx`Fd@oX$o""jx0VVKTM%u$S1Je>ak!%r(}nlB@GQhHEW<7r_>z&RVBXF0~M6WaUEqT??3}H!vCwVrT-+aQ9}8eKZ(4ODp$WCE?qE;D{_IK|WO`#~pSTdixgJa`d3bV_y=L#oq?u(cX|U?!hP~-ioP)ZUHOJ8d*};0VgaVl!VMXCyKhG`+6wQ8j1>ea#G7dHNownYeKS|^XNy|f9HQOs-OM*>n;ie""nwoR%RYcjeZpqUUq|Q{@52h^IN145J&1a3U#~jPnK0aK&d3i~qHdLbi{H@k;?us5Y*zocf*@n^aB6E|!!fP4Q%Hx3nzYW1QQZ42y^ru2=G=TOOqCbw|xjTIna~z""duyA6$gl>oNX`nJ$UUlcQd&5t?n(mSO0z~;b2Qh9VwLd9e)S+5v8AUW?kV}GE*zZk{M460oKod9-ISQI""0gV^pJyb9~6iut%7p{klC#xEK#h0iDCAiLwnXUt{`dVc0%c*DFNthN{dz%#kkYZtc{vo~8QH!u(v)*o~Ty%+M$x9ShSK^=ObXclGly&&lCQ@K*MUVnM(*R5A1(z=MLI(2Q48|7$6E4nd%%NF{(r^Fv)1$j?gL6a""x2I8rJ4sL*y~@jG_{Y7!xZhtFF&*vl3X8r7=L6aAEs+{?PvR3L?U_I(8d2wrP`dD`xPr(9C{e;ZL1qbK#W""#js48ukMqA%N*qtQ&82XzfZ4p0oj*zF~)i|L-yQ@T!x8)8XyhR!+us3b?VZyH%?uQ+uqf&UUf;kQ&SfWg=E|;-8sbXd%4-YRUs@V""?&@A-0US|{R)r^Eg@<3ab)O^=a$|P@2-ZAptnZYae`;5q>$>vScXK6fpvUImA5X}jC|SL9Fv}!2)I$VjH?Q<*o7q2W2jBa_Nik^<7Kl$|iA<{3*mO-USuQCZun@~zOpH4J3=Ly%|4%OJ7""Dt%mRZX%&&HCVG|%ncCT8O#m}FzZA}gc`F>vTmqT*7it-7-dlWLE5>Y@AH|;Q#RvQmSvA+;j)-5)cy?A5UU|sFy@E;VbG59""vYiAY1@ywh%lQIVBJg|#q~wLHJen`9O9|}SSXdE3>2mGh5-Y4uJdhpWJJ%=8lGAf-yS""uvhvwlr?87A|u-G*0eJAX+W{aZ}509z~Hu_R5(uD;%T;oYEA+)MN2kd`(W3q>N#~8F4(Nz9+_=y!wSQat|PyNt4H*j48Rx_unaWX""#Ixl{=cbX-d1*Z1U+zMJ@88$@b3gu7isTsv8he{cdlOtzOkWU0b#Y6F9ZR?|GB6nZW$bW4Ll0~Ai|P>Ke}P&-$sWsP9v)-@|Jo;|6_A6ys?AYPM8(a|Tb;%>|bj%41&Unu@eC>UCY4jFsaHB0b;9BBl4guf*PB1LO~pMGC(@5e)M{XJZYhos;`c`6ZK{""4Zn|?!UdMd}TFsV(8kep--vCvs&!A-n!!`-;dK`UdU^mrW)2y_=s5z<6ZN`%RnY$""VYsJMF|Th|RxFK-!OQ41VB9|AJ|FmyCge_f>(sg51RZ^mz13KU9F`+iEc4aDlJ);a_SQtn&@sMm$""^cCs#9PqK6K3Qq)h++{Z7N)~3ZJ?USgf3MRM!b3x|DGgBM{3+ndwH-L#sF`fKQ`f3^CGrD)y_z^*pVEGFg)km_e20hzTw9SnIo?1zW&*B!7tL)LvSFJZnsJqK+0^m#GJ)}@(G*Kfg}=!Gg<<3>H(17{CX*@L5|LI3clm!hr-""o9fmX0Mg=W&CpwY*3?d5X>pGB^sN?c?LW;&=eFE?xaVdgKXw)=(d5zeC{ILJpfc7go$LJy
QFyRfN@@TcUxQWXKMuNgy9VT&CSuxJf4r^WO|Y>a(v{o#9u+-cOZm;*CRzFE""FrJ$k40dzzv>*?;eIoOIy}}edA!*$t_!T@r;C7%a-LhY4+nV$yxLET-t!CWM+;8ZoF}p>vls6Glza{|D!@ism^#HE~l~bTsX#-vH""L`B$xZRMgr!^o)G18l""&pWTb6!qFAG9=SkXyCz;Y~Dt$gLZXo6RUOyO`Kv<-z|H0WT3ErkE5l>ee_*f3dGt5m`k~RoV>T;WUzd91)ELVco*S~oN&UFg@F4B""0kSGPnjohR6H8ftuE2vMwpU+NADl@Qmm28LDJOzXnn4)^o5ZTsY8PJ}gU@lh|S2T~3>a6;C}k?5)`l""I4Sq>?6Sk@A6e!)sG>HT_|b2ModK^NtFyt&Inr7m(~ITt78<4gnP}Wyf!D4C>U;r^B?xnn=`u8PQ^;$2rd)*qHvhUrQBW_)~zpR+U^OTTm#H#y!r>oM=EJ}I1J9({swy*c(b8EFXsiB#_T+wcZ$i5Ew8""de0CXD&_agYzTcwM+XKog&iM%-&4c*m(hc8?yLUmCB%e@e`1r1+~&E$_5?jQ!90eDf)JW?1T*ZGKa!xKq=Urb*@jWwj""N$fUxuUv(G=ZwBpJFkf8DcUM7vv5W3Pd5KtIfXlarmSj#rs*j*f%JNhiuhUnU8ylDa$d;3>#~U)RWMGSh^QLuR_u)r$&~7DGW>Fj=""@gEUnotT+kFFGIvd5mlb=ghr4jJM$=)yi_hCc>6BzFyx=QYlb7KbUR*F&HiHHLpL$1#n@~>V+0PTE)""f-$UA0}cAWo79P#g&lGNk9AYvmj&Ug_yw9s*IvPC>XmQ<%I8qO3GG-3vvgEdkwbC8zDU5r0ca}VWtiZsi;cu1q1?eY`_qDXGf)lw""@Fy~a7J=VUqLy=9cIxF>zq<""10YDIj}N#8YKAQY@U4#>C2EjAZK+J~^<#LmPCfUWntU7B#xOHKxgX&bUF4=614Hc~NOFm*6VrG#n;kv4Q<|0jz-nCA3P&6i@YOqr""0PeLhDL`gvdmTKz^1kns-t*@p|N0#&?`7pe4D>(#SOHw#HZ7aPG)L5KSnJ#KnB&^ZqwrwRdnITo!MHKvi9mG%1A~aa#3U0t9PRzd""S&2Qu%F@{Q2?0dzt8;tbnAC~X-hsHS-Q@4l#@zlr1@M#SA{$7$B0@{Nw>|Pe?pJ#tR>i2s)Trojropq`gb%+8y4jh!4~l1Pu^=Z6``i3IE>rr!S;T""8d~dS^3NhYMto;g;_#!n+M-H$y>FW(BKs>K+Bde8bycXj-Vqk*^D7Ljw`w8H3Mm#g$GUPtJTvT876@%u|fz~y1T@zoZ+1L#!n*d""F;zD3?dU0uTwqRC#;4Hrz_O2j5p8b&avrf1pzv8e-tNLbj!p^E&rUr)IaYYn?Clkz{OG_eiaA~j8pAZ0Wq=TXsXV5&Vg8W8crPE(""%fd*8_S+Xm6oU{|h&ugSL4bW%PnBRu^$!uXNykih*dMJM=Y+w}84%m?tI`JHF!8p{YgJ&zF}Tvr2agKniajsHmer}>u^clZE`uxx""f6`@@pymi)QGQCWUacF-y^Bwa)@_~Z$^13VD!X$`yCBX1Ry~=00t;+?hge$We)4!Humk0Gi3hsem&qGg6w8=<{5>&=y5RpOzGmTB""sTkAbeW5CwR@zmoHd_qyCCjNbCEL>`tP&@xC`Y@Zkui<}<>etKKs;@f#}iSmO&RUSLI$4^`?}v=6>^!nLs+)=w!E~DC{(=}&im8@""eMTc#)GhJ*e3^;@1$eZa~%aXgT~ESZDfSu}^=J?F^3YF8&CA`nc6afs*d+BjWpGBJ|T-j_pbvS0OpG""2;}ucP|*>0Xj~GUA@5q&IzbN6txVJlSZBK98V=^Qz+(|vBT}X#3KORC1vU;d3#p8|qZW(jJq^azP0FrM7>$47ealNHhpFFSdHPaaD{7y56^40OXP7C%Q1SN""Tzqb#f%VTbZjN;nBAyZ5VIcYL|{JY`0lux!PkyZ;WA4v%W^{1""$fnG+ExNXS#5$""6*?6Q#1ZxQKBKQdfkqVJz!Yv7Cd?zL8&cGwx?XXFq+`Pde-oZ(CB?CTRg{<(W#6_Og5tR-;n#7BaB9-CBUqEf8|60j;C6!v2!q|y""I@#8rUs$H|s`Wez;CSwn*AQ&)uJ`6ZJ~Bbe4z(YcD4uvJk`(E1Ixrm5V?-yWEDXE;iAi!E|9xC(U4?8Vz&lAdz-^G`J3j8|7m%VY""lvBVktpAttE@~G*s605}dKp5e?&yk_!UR(cq2u#-K_9f4h6f`#D7K4{ia7PZBN5(i)3iw6Bw0R4%w_IB6|ZviPVJWfG|OrPX#auR""Hp+$A*4>Z*0D<~|f9R^Rd1bPdn4=6@qWMfR$C*`vp%ne@mwict+b@_1T^<~2P>Eq(4n9biVz5Qg9ral$nT7vLal(PQ!PQ9Mpv#5E""5nnynf9qaufF5v4vBTNMxq&!Ui&wT)IXa{-!eHg-*t=1_cFiqhXzS-WrN99H~G6Sj}T&m""l^<~oG=?p{XGi;-FuM?_XBsDemRO!RN{-a@Qq^vNIpx(edVFO3UBp_h0x@c%=z_8;-Hzxd|MPbt3w8-^082;6tylol1ipjm*@mWP""CDPh?6PH(!y!#1`1Stk8-A1*nVyvDc>xS^$0MfWuzEe_+YrSEEfB(DHC=(^R{8ML^d2u0~TAv{4Vr9Bs*_I$qz(IBFHYh)@L6@4$""Gq|z25&7eqhEP2r_+N&wAw}8UxR*+}qI>`&`e_Ai""is>~@l0;uj{Ga|_1puPg1){E{Pj%h!3{K84{_+T()o=+-d{ai!D#7W-UhvZvsOkSQ^d3B43q?J%!0Kt&mSdN1p=)!8tOP>)1)7r-""rvL6e5MUtm0amt25j4{Gv|e78jK%7*R@?3$aq-7K^4<6Bp|t{3sR8-+IKp3w@w(T8CbqWWiY?Zxm_|ShVeNqZvznY4*&>P+++2WR""&_9ZSZ2Dc}fx>zn+}p1#=)YHJ|K~hzDz~vYjRss7Beh}3c_d05M--doRCHSwME+@|{DyIhcMsj)Va50(SI4uFRjI09a)RIMpQJ7e""({Akyn|Kov)>VA+16on49gFXYSbpRBo&1Jt{8~*2>QlWwL`8L{T7KZrn9fm9uDMDwCvj-+)(5%m8WT2fZcpV2xd7h|&_l1wKpZfT%g?w?X|l(+?2k9g3`4orgK}IVm)L$N)aoYU`s{SRwWS&AT3@Uu%Qu$@nVDLUkqk)G{3iI`HD}""(e?l#mwWtzNosyxHRzTbEM;fgf2Z_~eozfer{BwpzO{(dvP}>Tu6uNCGGmn-R~ibjV%d85(e)D?jg2JS$?u1)K(Utj%j-Q*J⁣""xjmY{Po@TxxSxTSw8Qa~I94%oLGXz0tyejhuHEg6;pXp#i?4k1T@~mKTc)sKh(cp1Vt-J1$3j)(oB_8-zV;~MZct2zsWPxTJrV?^""M}W)SIJE3}9GUN*8DOxG9E1D!F8$+4+c_ExBcp2i5_C84UgG*6vnKTHqDyKb8L--qq~>Jezp@hvWq?R`kLLI)rWXhD7Ji^O3&dtG""(RA0?g3ubu2Epb2qrF$28H{MJAuv
$?bfQUq!{ElUBJaPaVgdT?0JivdjY8L&HLY|Psc><+LRA~nY;{17lGa6rFv4Zr>>Ks3!Ofmh""m1qnYktJG)D8<8HzIX>_$7NUOROEO}*4gHguxol<@{<&~#~<3=X}4#*W~OPabWetKM6Qq%=|ve3lGp@nYHMnOb#mlz""k!0=&Zx^OMPm?E-lhZ?v;MDB!KZzQQqc*gkFd0QN+M5*Bl|pPvrd0h~(*nyB+DQ%wl?~^E!tZ>SekL9z;4fGCw^x)88rGhV-NhI=""+c4$VEjl^ENH16yRTbIle2I81O2k>uMgd%*h)x0Bg4dTiH4dr>q5m4KDgb;YAcT*Kl`H|=x5%EuIsV-gDUfmiV_8^>Ojyj&XM;=y""09Hzi&`SnvmCCq>Tk~!C-+zm+-A-@4QD~19qDCMWMrP_IU^<>6)Zj-pw+xly8!(ukoESDA9~bJzMuNVA;+EfX4JlPQ*=QI{LG+-G""uP||j;mTOmmMp==2KH*YVdv-CT(h!KETO0X!cYg!$HW(`dN}SxS0D?uJ^%zZ_!z3}n4XC#L{eO$cH2Xw6aO_!a=FlpjjB>v`)QULT*up^x>l^m0mGQKsos8*F)!hdLR&>*~MCI$fF{o)tq+A(DX-L@^DmmxtBK)}*bF%Sg`U?fGAdk7PXS8zNcW;H>XI""BUAlJ(E~(~9OPp=XABT=c9d183MX3@E=yZ#{D$_+C?enNxvt31M_yNt^YGHyk>NR@BUq16!bAHe2{mO|Coa$FDdH`UK#y&q-BS""hm98l#YO5yYLApAtwj?upIM9<^=3davWpUdjSY>DesfDNG0?q8C4_Ka#H07dlPz$YI(@ax-D~&;GSwKd*o@>87KpG~9VzUQ^4l?#""s4Oc9JzGzINAq5-C9mw*ILwAcH1(b;;tH18z|!XkD0l4DWU>}n?X(|i4mQ?lqeD_te3exZOTHma+9Q!w|K{BzIJ}d<92PQ3puB&{tB>i`QrnW6dQKbN<>|}Ry""wJ2>cN%WqtnC}228|#>qis_GvPTb7b45elvqa>I1LvZlLS;A"">t)7_|D<$jeZ+&$@nl5thh#gfc0_ORnOeo~Gk?073gFI$(PDbSt(fARj_TaKN%+Yf-r""AgNq)ehb&C?zv;f22d+W0BN!LW%I!Zdj;D8OGma1XE)dqgzP^ru#P4=zYL8OQ;>OX(6s&*qqM*yCG<&m1#}h4t!nh&Ez4+Y$o;jDzFcme0dLZeN""LxQxvnAcFXZA~K#wHyT05O6!JJWlAkyy_jOU;HwpWG3ZRr2RP_8Yo5k`gItcS3mfr`4K{pYF7qbt6f@AOjL_#2_IggA{O6&PcscP""_I@sqdnGNFJo!4H^na~KDIqD;F^vDAAL;}efLMBkj&CpawXxWlC4r1WouSYaPivVDHC7M)PeS~5DNu7xSLyIp5`8C|8|-Vb)SzbI""v^14}|I&@>trWz}$^;t{$iW5K()XInKZ$~d&(1@&iGOXXNHlz=aT{Gv$;0lDtJ10SS3_RfI$yC0?g|$cS>jPv$kMGlk`X%S?jeclev#q5%QJm=#s=H_T(wLBEsnMS|y7""hvStmR2_N2MfcW%HU)c7cwYGB><)m=+8gsxj!v3ve^^--7xe1ZRPV5U1&O8hqcuz|a(jDp!R0to$YE0Id`g%q2Pix-#M1*?bTO^S7;{bKK3^1fUWI~5a`==oV|!{manu=V#YM*>*#LjyNT5)ThA;G0$}hLtxDMn!(a6NL*aTFA`7Hmp8W@""5$h07L=yXt$L()&2QVr-jx#;-N+4I`9~B@;stk36xS""7rvKJOnE}Fhy5!qn_FgsTI6)^Ke""0|wOg1?vz0a~Xj6YR$dGI2i^qzHobOnzREvdFex1{Mu>_O)ThEnC!mDA>>5m9n4&E`ux$Wz(8V9K&u%2e9Aa_0K;#~7w)>o*U&DAf${Yg`~*j*#YQ?j&;|bKs7A""liTgnLb|4rN?gyYJBsqMT!R~UnojBJc#%5Ry;r>p2kCZKSI;0X8&{Z_0X)4sy}*lHSo+eD97)>8s)ZOR#*id_UH_r@AUj`iJ8TcvA1(I%U`Ok$zv""r`eFH*d*>S`JSnev`_fVN9h^0*S05Xi8_WpC|uD@=ozyTEEvoFL2Ni?rzy;pHdM@sy3LNZVpKT9FxDx`>mTz+ANC{>wLBy5fMjlY?rK%vAMq5KD6zbpdZF;zrwDB}~+9%t@#;M)(R?Ml6~"">%A89thn}(vv2E7NrR)$pQ1niiIIKft3oq_CZ0jER$h+j8@qlw&VF!mmY+V7er+N+?{r*XrntXgc5znz{#ZGzy+lFN58mWF7hnII""L=r%sDzdYvaEFzGX|MutZ@oGMY?0~=J;)_ry$=wF2515~Yf5^vx""baeGPNNN)9ZF+y`UKct5KxzN`Gry#-%>oQK&YHvC;JYwK))qPffbVac6=iey5y_36Q=)b~@$dzwP5Bnk@<0$)6Q4ED%cQLMA23QH"";+NJnzP1(MxM1?8Njqfs?l@q{w!%#c64#6H`TEx3kKP$~Y~EU~?A*K&?X2nFJ3sX;aVbe|9}Ud92)9!U>U2S@nBU(;Gv?^m@S""wL{oSTz3f2i3}&Kvu?{`Y<^b7ieYWwuCFeck~?mZknU7Ydn6}(#{GKwF+drBjl(6}Tn#sZ$oLmee}^7dm;~``|Kx!pl}YP+UI7WJ""efyM~+y`~+tujH9!rRoT!rG!YwPEJc%%3@_%mNR+x+dw`QA2ppgCym{L_UBi5f_TIWK7aip*9dJ)Y`6FGZQnqK(PjoZtJk$lp7*(?99b+QQ=|2(wskv%4O=""f8W%wW2KwK1!;|$9WC$qe7B5pcC!?SH^&zJjneE%i0RC0>""Z0p3*XbH-c1dq}h?(gZ~Fy-E>DwP)&?dtA)+2;G_l`_zS9%kG*4YV""uB`*dx<(EtB`x-YT-~+|lLa=8R*!G)gTN(X#aNsN%jc4Z#OLc8)FCiGg=Cy@rajmW*cH4PR)uz+-{nbT3w?nvVU*_`0++fJA+u0|(>WI3Y+=n9+mpt*oh`%L@zyYb0dRgsU4TK""uvsZks@n9kchPoIE791S`V?`MI@*YVllPO`f%za+)yEgv6yeBt96A8q@CthCF@s;k<>@ROr7(aHe4t9rI8E7V#H^)N%mvpejhYQw""mwRl3mj~x-ouNua<{e_On7FHG;_4yPsZ&HzmJv^xK2kE%_UH8yB^IA_6LIJL*i~(J1J&$IUD~V{1e)E8+Zhd+i^?_f={#1}F<97D""NZ5tjis#;a=3Jt^apAjuK9|E+q@*(9}&*(jK@xHH~""gH*)aw$mB;3L5GYO4m-&d-+4t@duN)loJXmM}a(=hAj()u79aGu%*;#PwS`gq89rI*BMsp`)MqwUv{{TMap0u?4(TFl397{l;kkt""tJ(LN6%{8inmnQB<(QY{L2h}zvcEYo&b+L>VUY{Ag^|ua1XLRUZAqg}K*rU^h?I&?poUYt1""=38ndiT_kbHX=5Nl1Ys~+O;`Seot^aNjK8@846WEQJ!xD3-OwUVI-$^rtPH+e_z@RwW8E9tx>wtQ$oBHc|%$1*_EEF_FvCuP5K8-""(r~l^+eqE=e><%I_cn9oV?v$
A}}V2ZElJk6*TtiWmeKb+o|)B^$wi!M""Q_>R$lXAL=IdDtzik|n2qi9_WS!=)oh@NQj0jpTDE_02xvsvXH2f}zegkCnVR-fJJ4=BHkr6!kRtKoMC{pltB-_BuDh6wy`rTM(nVt{-LwscfvTf~OzedS%Hc7>3)Eel#sU)9-Dg0^|f8KR2iPf~Yw26@W#A&c1!#{D}VNwD6(wXGfAodfL`o^OsMxrSQJ`PKirZj#P#0(m@XN-)HB>_Fp|G7MI""H7~idf?plwJxi?r$$D0F1?LwNn84I*5UdI4xk<4=Y)1U$`|hQZnROMUOWg*$gSs4jyx-4~Gz!Su3seC`Lig2z_`_cLOl02Fs91@_NS1#e-4gmKdIIcC5C9pFeI6q0}v#KErKC|@~giz3DlR{L$8By3b@!MxG-7u""CbcPX;IVP$Z_OX>pX0~}Xp-W+gQQ)bsbnL#ZBxspUwIih4tv!|q_6jA35uxMB9i&C7V-i<9=nm(SAa#Ap|@wRN}@WC1x!ODJb8k@""{VwjuL3RTlH;`x}3OA+M@gOTI0{*wu7QaATlR#ggvJLjAdO-5@ni61GZG*bzGUQwpcMzfV3T$%&G%s{6ZI4CBDcRfxBd6B22oY75eVsHA$VxqO%E1@Jm0QxizI-qor5{#E{6wo1-t{^vu5>iFrN-*6u=REKs=A&nk;""S99>l%(0N6nV0l=v$a%3h8fXB6p&HlNCEb@h1L&HQfwGRg4ufgEi_QsF0nZ{6ljnE+J62#DoY)}{H$sFliz9zjeucTkB~wrCGZNw""b_a1lx{;6jt4V2{T1MtIr0{bP6S}pOw#_Q&eGx+t{&ivW^z3U^t>8bFCSTuct^=yDCby^1h=uBaZaY6+*iD4tJXNfu4R%sxG{-z6""tzc)L);UGx=EUN|;VidC|1AtBC>dd*$%N5h{>_!+St?6mJ0vZJ^V2Y""#T+aHjK);yS@yagNnbL7vp6!A=4Jq2J|$&}NUC)^=W5LRW8I|P-S7wimo0JceaV}Krp^2ccuP;lrip&461DmZ+SfItk>FX0$L>IQ""pWF9SvNB4*%NZtJ1^CD`mYH!Q)|CN!^ho<2G=ovfd({5l0ju;P2!L$;!6_;l@4U@D$khmBSnVR*6P-|e*FO(Asj&`Nc)8FMR+ck&""TOV&%t&c5^yMA~gV=85_E#uGEyTkr09=fv-X+`rWOz3D{w@zB=M_}ZAtKMS&TM01s@#tkdI$r5}RSRJ$D;MX}AKCIy@Oy+Aw>!hU1Iq@t(r""^ESJzE)bX)s0!TzKxiBhk3Cbee%6_E&U;!gkHcN2uDw7i+v7E#bU8ah)v>O2I~GH2$3aayjC(R1=e9)k2-H63DJ+PF58vQz$tulW""E?PAK0thS8Vte*D7iM!#5)~p2$79>;PTJy)GJ2DrP-2*vn&X&;6Ryi0k2V{+&V}@S59&<&({c5gf*f^DwmsXQ^))VS%{)!5s5=urVH`)G0QIlXfijW3f=PdZqq%2fe7wpU4B&ZYRQre*n_^mtlH~8Kw0Z!&=i6Lk0SUYJb%i9""b{r%ZrZI6$RHQDEc%`s!*f78z%6^Oic~cI#a7Q4>N+Kwc<7jjpFC^)))if`&-FptV85u>Evw(tFF)nBFq2wK(GS8A~9t>ah{4Yk=""NR9(!OHIBr({tMiW`!psuu+47JjnF=dd3t_h-=uC4>~a)H57HzS%Jq}YA4Z~x|ZkTrgKK|6?eNtMCNWf>yoXF<&@7@z;?#g(=XkP""Jb+Z4%Vuc%;dMuqD_sT{{1A8fTG}$8W9|o@Q{RrC^2~QJO5;OUVIVs4V_K_N8za{7Lr56oeHni^m1Rwh7Jx8~qimxKZwB}xQ_eXF""y&D9L(J&rYY+3AoJQH#VC;m{SI+$w?f>>`-Eq4uI`mZK}p|9*-|Hi(pX$a}I7TcluXjT?Ad6W(h`gjHkNeKgJ""@&?p6B+#hf9mu>C!5&_5p0=g4;-5128x<6Q;PnM)GdyqW*Mj^z?-6r<)>hg`h0C=B>8@_`iExoYUo4d@m|}4ZQv9Ru|#q96iGxN>K&w?=!D&GgAKx9S;FaMaa8GTov~P5br5>${?`PVm099t!4myz""!rvCdi<>;!qE!MdE#mA!Lp_k8u(T-eP^jR^P7Wy*c51My>N3)7I5PxVh~q`^%|aLPpEY5|$~fd>mIFsZVlxh@wClZ)cdhWr8fzfEXb1cr2NqCMw5-45;)UR)H{@G7dp`8c8g*!&UxRQXUugpL""juDw=kR60VLF@7W0goL2N*1u#t&*JL+<_@_h*xhx%0jt(eA)~nj80N&IHLn6xrcIK&Vp^6JX2R1+EijuxKxZ&`_-pZ&RS*Hr_i2U""=&S@jYUMuY)(Eq5q6A1A&_Rx@CH2_;cUenUtP|I?qMV-=S;sO9n$%3=!`xfZ!KHPW1K97$AWPT""n|9uc?U%ew2p>DQikRxuP88g<&5n7x6;hBXo%Z8W?&OhsfH>~q4IjyKhV7)1*=W``cu;}y3r0DT&Ojb1$wnd6Vb9yj8a&NQW5t7%""B0qwSL#dj88V(}9o_f{Hg1E1O_z^!p0kpzH5UIXE3L_~fg*bZ{TIdQ8I1V(}qs?hX7(zl;%!2PgHVSPlIs^FT9oCO8Br6)x_*F_Z""tGWA?0xA66h1-U3E^z#{1l8Kox2Xz!O6dyJ(I3%{;)ECjTluN#;w-D!3=`KBd5k6%Y=N+PRn_%jON8*7^3A9&uMd9aSm{d7T+*_F""giN_0?JOmdk(S;y@jh}LZ?&!v~_!1N>;m+3Wuv|J$ndx""I1Kce?fB?IijjIQwKRw^JIp;7yA|X_M2DyqQG@-yG1>XEuT;%51%m=MA}n{FGCel>`R|W$lDkV|(Z=wmaB*w@>?1W&""2C>dUtFklNrOix)>}v7~lBvjjAuuax6~72^jh?kw4-D?{q%cTVI2wFEvXt>DsJzs)&U?C5#@C^FTw0{^*G#vpMdFxwBNp3tpT$8r?5oMTXR7fTlu_QVPhB""1D)o_ju_dwUD*j?lk0gEI$aOTS@T%^6iOZ8+CTg49vE$mLMw&kh6GB-(sXgzIT%g+jdUwu>of_SO5@zoQ{P!#GMMT#qO{sYSo$l?""sqegt!svC&Fd|1SSr$wjIy2BN<%oP@@i@OBfeW4J!x{3Ntm=56XT4!;!<3X?l;K?V#Kz%;3X5K8S7&ETG#qkpxO#SCG+2`hb`d$4""n*e2g8oE3(t>$8l)nElnp*Y%~t9KF0EK10Lb+H2ws2X8mcH$su{6sbdDA-RN^o1u!idB!xXpx4K2Mc8KUz""6i)PagVT>mfAb2I%MIJ=y|h(mD_T+C3H6BRQ5sC5G&K7?59!wL(kfjEPSNbf^K=ohJ8JW#XfwA>a?Y168PC>{UeT<3IdCc%=C0lE""N|Fw*^v>}36}39*Wz&fePChjHWPD5m#(2Hy&RX}u4RSttddLzChny7yT@jyK2M54^!sl0L{|x-QoBGzundoziQcQd|I}*i)lL=GMr+tqgJ3qkn0?{UMCDSMB"">};Hv_!}EH&yWShU0mx|A^WR^`V!VBGXv%i_Dh2~)17$3X1jTmrI2cMS~Bk54q*8sgyPU@1ezC*+z!z9RSy
vE4Y%(PVgG&V(qM`E""`xUomx2cUI8f3JfZ`V2QdwZ>>4D^PQOA~MTNGH94@k!ngC)QeA)5R$oJ;_d+f@p9c3f9ov0aNs#CPTC-qICP?*9r!EM|BZL9y`Q^"";FX3?Uk+Oc#xGKg-RN;ts+$j&9qF?>CAG%=vZLS+g7NiXRVzLZumAu6_HsiSU%NN?00H-r0mt|R(7W)^vBYQl0ssI200dcD")))
\ No newline at end of file
diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log
index 4a2321f4e0..612139e779 100644
--- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log
+++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log
@@ -1,7 +1,7 @@
-W0403 10:16:55.161000 48749 torch/distributed/run.py:803] 
-W0403 10:16:55.161000 48749 torch/distributed/run.py:803] *****************************************
-W0403 10:16:55.161000 48749 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
-W0403 10:16:55.161000 48749 torch/distributed/run.py:803] *****************************************
+W0403 11:59:07.019000 46017 torch/distributed/run.py:803] 
+W0403 11:59:07.019000 46017 torch/distributed/run.py:803] *****************************************
+W0403 11:59:07.019000 46017 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0403 11:59:07.019000 46017 torch/distributed/run.py:803] *****************************************
 Hyperparameters:
   adam_eps: 1e-08
   adam_wd: 0.02
@@ -27,7 +27,7 @@ Hyperparameters:
   iterations: 20000
   ln_scale: True
   local_rank: 0
-  logfile: logs/df967a0d-f7a2-4514-8cac-646d1e38abd5.txt
+  logfile: logs/56d7b60f-bc99-4d4e-b53e-cc83e5e5a1a1.txt
   logit_softcap: 30.0
   matrix_lr: 0.02
   max_wallclock_seconds: 600.0
@@ -52,7 +52,7 @@ Hyperparameters:
   rope_base: 10000.0
   rope_dims: 16
   rope_train_seq_len: 2048
-  run_id: df967a0d-f7a2-4514-8cac-646d1e38abd5
+  run_id: 56d7b60f-bc99-4d4e-b53e-cc83e5e5a1a1
   scalar_lr: 0.02
   seed: 314
   skip_gates_enabled: True
@@ -97,36 +97,36 @@ warmup_step: 6/20
 warmup_step: 10/20
 warmup_step: 20/20
 0/20000 val_loss: 8.3172 val_bpb: 3.6146
-1/20000 train_loss: 8.3192 train_time: 0.0m tok/s: 8366587
-2/20000 train_loss: 12.3839 train_time: 0.0m tok/s: 8320936
-3/20000 train_loss: 10.8345 train_time: 0.0m tok/s: 8218162
-4/20000 train_loss: 8.9588 train_time: 0.0m tok/s: 8165338
-5/20000 train_loss: 7.7775 train_time: 0.0m tok/s: 8134199
-500/20000 train_loss: 2.9043 train_time: 0.8m tok/s: 7887531
-1000/20000 train_loss: 2.8870 train_time: 1.7m tok/s: 7875981
-1500/20000 train_loss: 2.9137 train_time: 2.5m tok/s: 7868810
-2000/20000 train_loss: 2.6565 train_time: 3.3m tok/s: 7867581
-2500/20000 train_loss: 2.7143 train_time: 4.2m tok/s: 7867014
-3000/20000 train_loss: 2.7601 train_time: 5.0m tok/s: 7866553
+1/20000 train_loss: 8.3192 train_time: 0.0m tok/s: 8448432
+2/20000 train_loss: 12.3839 train_time: 0.0m tok/s: 8340059
+3/20000 train_loss: 10.8345 train_time: 0.0m tok/s: 8240442
+4/20000 train_loss: 8.9588 train_time: 0.0m tok/s: 8184109
+5/20000 train_loss: 7.7775 train_time: 0.0m tok/s: 8147155
+500/20000 train_loss: 2.9100 train_time: 0.8m tok/s: 7873796
+1000/20000 train_loss: 2.8944 train_time: 1.7m tok/s: 7869012
+1500/20000 train_loss: 2.9170 train_time: 2.5m tok/s: 7868390
+2000/20000 train_loss: 2.6561 train_time: 3.3m tok/s: 7868298
+2500/20000 train_loss: 2.7093 train_time: 4.2m tok/s: 7869275
+3000/20000 train_loss: 2.7589 train_time: 5.0m tok/s: 7868349
 recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
-3500/20000 train_loss: 2.6889 train_time: 6.1m tok/s: 7462691
-4000/20000 train_loss: 2.6198 train_time: 7.1m tok/s: 7373476
-4000/20000 val_loss: 2.6435 val_bpb: 1.1488
-4500/20000 train_loss: 2.5716 train_time: 8.1m tok/s: 7307359
-5000/20000 train_loss: 2.5156 train_time: 9.0m tok/s: 7254708
-5417/20000 val_loss: 2.5308 val_bpb: 1.0998
-stopping_early: wallclock_cap train_time: 590109ms step: 5417/20000
+3500/20000 train_loss: 2.6858 train_time: 6.1m tok/s: 7467443
+4000/20000 train_loss: 2.6180 train_time: 7.1m tok/s: 7380198
+4000/20000 val_loss: 2.6437 val_bpb: 1.1489
+4500/20000 train_loss: 2.5736 train_time: 8.1m tok/s: 7313799
+5000/20000 train_loss: 2.5146 train_time: 9.0m tok/s: 7261199
+5421/20000 val_loss: 2.5309 val_bpb: 1.0999
+stopping_early: wallclock_cap train_time: 590042ms step: 5421/20000
 peak memory allocated: 30119 MiB reserved: 30156 MiB
 ema:applying EMA weights
-pre-quantization post-ema val_loss:2.52826494 val_bpb:1.09875482 eval_time:2011ms
+pre-quantization post-ema val_loss:2.52840565 val_bpb:1.09881597 eval_time:2007ms
 Serialized model: 132405827 bytes
-Code size: 80967 bytes
+Code size: 23948 bytes
 GPTQ:collecting Hessians from calibration data...
 GPTQ:collected 66 Hessians in 9.7s
 GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
-selective_prune: unpruned=16.03MB target=16.0MB
-selective_prune: pruning 221432/9394039 lowest-error ±1 values (excess=27679B)
-Serialized model int6+brotli: 15882457 bytes
-Total submission size int6+brotli: 15963424 bytes
-final_int6_roundtrip val_loss:2.55770163 val_bpb:1.11154767 eval_time:7550ms
-final_int6_sliding_window val_loss:2.51604422 val_bpb:1.09344383 eval_time:75930ms
+selective_prune: unpruned=15.96MB target=16.0MB
+selective_prune: already fits, no pruning needed
+Serialized model int6+brotli: 15939825 bytes
+Total submission size int6+brotli: 15963773 bytes
+final_int6_roundtrip val_loss:2.55451375 val_bpb:1.11016225 eval_time:7465ms
+final_int6_sliding_window val_loss:2.51193797 val_bpb:1.09165930 eval_time:75591ms
diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log
index 052a24fb73..105e99dc7f 100644
--- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log
+++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log
@@ -1,7 +1,7 @@
-W0403 10:00:53.294000 47711 torch/distributed/run.py:803] 
-W0403 10:00:53.294000 47711 torch/distributed/run.py:803] *****************************************
-W0403 10:00:53.294000 47711 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
-W0403 10:00:53.294000 47711 torch/distributed/run.py:803] *****************************************
+W0403 11:41:00.268000 3338 torch/distributed/run.py:803] 
+W0403 11:41:00.268000 3338 torch/distributed/run.py:803] *****************************************
+W0403 11:41:00.268000 3338 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0403 11:41:00.268000 3338 torch/distributed/run.py:803] *****************************************
 Hyperparameters:
   adam_eps: 1e-08
   adam_wd: 0.02
@@ -27,7 +27,7 @@ Hyperparameters:
   iterations: 20000
   ln_scale: True
   local_rank: 0
-  logfile: logs/04ef7457-e236-4fe1-9a62-64e64cef9b0c.txt
+  logfile: logs/6e28f26e-23bc-4ec7-832a-8bee511d812d.txt
   logit_softcap: 30.0
   matrix_lr: 0.02
   max_wallclock_seconds: 600.0
@@ -52,7 +52,7 @@ Hyperparameters:
   rope_base: 10000.0
   rope_dims: 16
   rope_train_seq_len: 2048
-  run_id: 04ef7457-e236-4fe1-9a62-64e64cef9b0c
+  run_id: 6e28f26e-23bc-4ec7-832a-8bee511d812d
   scalar_lr: 0.02
   seed: 42
   skip_gates_enabled: True
@@ -97,36 +97,36 @@ warmup_step: 6/20
 warmup_step: 10/20
 warmup_step: 20/20
 0/20000 val_loss: 8.3187 val_bpb: 3.6152
-1/20000 train_loss: 8.3201 train_time: 0.0m tok/s: 8448619
-2/20000 train_loss: 12.3377 train_time: 0.0m tok/s: 8301630
-3/20000 train_loss: 10.8504 train_time: 0.0m tok/s: 8198995
-4/20000 train_loss: 9.0314 train_time: 0.0m tok/s: 8144344
-5/20000 train_loss: 7.8217 train_time: 0.0m tok/s: 8121810
-500/20000 train_loss: 2.8999 train_time: 0.8m tok/s: 7893447
-1000/20000 train_loss: 2.8889 train_time: 1.7m tok/s: 7878418
-1500/20000 train_loss: 2.9164 train_time: 2.5m tok/s: 7873654
-2000/20000 train_loss: 2.6591 train_time: 3.3m tok/s: 7871330
-2500/20000 train_loss: 2.7152 train_time: 4.2m tok/s: 7869973
-3000/20000 train_loss: 2.7612 train_time: 5.0m tok/s: 7869082
+1/20000 train_loss: 8.3201 train_time: 0.0m tok/s: 8389972
+2/20000 train_loss: 12.3377 train_time: 0.0m tok/s: 8284614
+3/20000 train_loss: 10.8503 train_time: 0.0m tok/s: 8203994
+4/20000 train_loss: 9.0314 train_time: 0.0m tok/s: 8146658
+5/20000 train_loss: 7.8217 train_time: 0.0m tok/s: 8099434
+500/20000 train_loss: 2.9053 train_time: 0.8m tok/s: 7916776
+1000/20000 train_loss: 2.8899 train_time: 1.7m tok/s: 7893412
+1500/20000 train_loss: 2.9120 train_time: 2.5m tok/s: 7883558
+2000/20000 train_loss: 2.6571 train_time: 3.3m tok/s: 7880543
+2500/20000 train_loss: 2.7133 train_time: 4.2m tok/s: 7880237
+3000/20000 train_loss: 2.7634 train_time: 5.0m tok/s: 7879763
 recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
-3500/20000 train_loss: 2.6906 train_time: 6.1m tok/s: 7464352
-4000/20000 train_loss: 2.6234 train_time: 7.1m tok/s: 7375245
-4000/20000 val_loss: 2.6467 val_bpb: 1.1502
-4500/20000 train_loss: 2.5757 train_time: 8.1m tok/s: 7308062
-5000/20000 train_loss: 2.5179 train_time: 9.0m tok/s: 7253712
-5415/20000 val_loss: 2.5332 val_bpb: 1.1009
-stopping_early: wallclock_cap train_time: 590026ms step: 5415/20000
-peak memory allocated: 30119 MiB reserved: 30156 MiB
+3500/20000 train_loss: 2.6812 train_time: 6.5m tok/s: 7104286
+4000/20000 train_loss: 2.6112 train_time: 7.4m tok/s: 7066158
+4000/20000 val_loss: 2.6367 val_bpb: 1.1459
+4500/20000 train_loss: 2.5649 train_time: 8.4m tok/s: 7039143
+5000/20000 train_loss: 2.5059 train_time: 9.3m tok/s: 7016181
+5257/20000 val_loss: 2.5348 val_bpb: 1.1016
+stopping_early: wallclock_cap train_time: 590059ms step: 5257/20000
+peak memory allocated: 30168 MiB reserved: 30220 MiB
 ema:applying EMA weights
-pre-quantization post-ema val_loss:2.53082718 val_bpb:1.09986834 eval_time:2009ms
+pre-quantization post-ema val_loss:2.53240416 val_bpb:1.10055367 eval_time:2008ms
 Serialized model: 132405827 bytes
-Code size: 80967 bytes
+Code size: 23948 bytes
 GPTQ:collecting Hessians from calibration data...
-GPTQ:collected 66 Hessians in 9.7s
+GPTQ:collected 66 Hessians in 9.8s
 GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
-selective_prune: unpruned=16.03MB target=16.0MB
-selective_prune: pruning 251744/9380258 lowest-error ±1 values (excess=31468B)
-Serialized model int6+brotli: 15879180 bytes
-Total submission size int6+brotli: 15960147 bytes
-final_int6_roundtrip val_loss:2.56002159 val_bpb:1.11255589 eval_time:7460ms
-final_int6_sliding_window val_loss:2.51787602 val_bpb:1.09423991 eval_time:75957ms
+selective_prune: unpruned=16.00MB target=16.0MB
+selective_prune: already fits, no pruning needed
+Serialized model int6+brotli: 15975217 bytes
+Total submission size int6+brotli: 15999165 bytes
+final_int6_roundtrip val_loss:2.55828836 val_bpb:1.11180265 eval_time:21824ms
+final_int6_sliding_window val_loss:2.51621374 val_bpb:1.09351750 eval_time:98539ms
diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log
index 94fd5cf437..d306ef283f 100644
--- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log
+++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log
@@ -1,4 +1,7 @@
-====================================================================================================
+W0403 12:31:42.753000 46995 torch/distributed/run.py:803] 
+W0403 12:31:42.753000 46995 torch/distributed/run.py:803] *****************************************
+W0403 12:31:42.753000 46995 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0403 12:31:42.753000 46995 torch/distributed/run.py:803] *****************************************
 Hyperparameters:
   adam_eps: 1e-08
   adam_wd: 0.02
@@ -24,7 +27,7 @@ Hyperparameters:
   iterations: 20000
   ln_scale: True
   local_rank: 0
-  logfile: logs/ed26df95-0f8a-4c3e-867c-fe8d4a1b188e.txt
+  logfile: logs/0820b928-8f26-40ae-9974-ef0bb3664b8c.txt
   logit_softcap: 30.0
   matrix_lr: 0.02
   max_wallclock_seconds: 600.0
@@ -49,7 +52,7 @@ Hyperparameters:
   rope_base: 10000.0
   rope_dims: 16
   rope_train_seq_len: 2048
-  run_id: ed26df95-0f8a-4c3e-867c-fe8d4a1b188e
+  run_id: 0820b928-8f26-40ae-9974-ef0bb3664b8c
   scalar_lr: 0.02
   seed: 999
   skip_gates_enabled: True
@@ -81,1971 +84,6 @@ Hyperparameters:
   warmup_steps: 20
   world_size: 8
   xsa_last_n: 11
-import copy
-import glob
-import io
-import lzma
-import math
-import os
-from pathlib import Path
-import random
-import subprocess
-import sys
-import time
-import uuid
-
-import numpy as np
-import sentencepiece as spm
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch import Tensor, nn
-
-from flash_attn_interface import flash_attn_func as flash_attn_3_func
-
-try:
-    import brotli
-    _HAS_BROTLI = True
-except ImportError:
-    _HAS_BROTLI = False
-
-# ----------------------------------------
-# Hyperparameters
-# ----------------------------------------
-
-class Hyperparameters():
-    # Experiment settings
-    data_dir = os.environ.get('DATA_DIR', './data/')
-    seed = int(os.environ.get('SEED', 1337))
-    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
-
-    # Training length
-    iterations = int(os.environ.get('ITERATIONS', 20000))
-    warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667))
-    warmup_steps = int(os.environ.get('WARMUP_STEPS', 20))
-    train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8))
-    train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048))
-    eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048))
-    max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0))
-    train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500))
-
-    # Validation/Evals
-    val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8))
-    val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000))
-    sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1')))
-
-    # Model architecture
-    vocab_size = int(os.environ.get('VOCAB_SIZE', 4096))
-    num_layers = int(os.environ.get('NUM_LAYERS', 11))
-    xsa_last_n = int(os.environ.get('XSA_LAST_N', 11))
-    num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4))
-    model_dim = int(os.environ.get('MODEL_DIM', 512))
-    embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512))
-    num_heads = int(os.environ.get('NUM_HEADS', 8))
-    mlp_mult = float(os.environ.get('MLP_MULT', 4.0))
-    skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1')))
-    tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1')))
-    logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0))
-    rope_base = float(os.environ.get('ROPE_BASE', 10000.0))
-    rope_dims = int(os.environ.get('ROPE_DIMS', 16))
-    rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048))
-    ln_scale = bool(int(os.environ.get('LN_SCALE', '1')))
-    ve_enabled = bool(int(os.environ.get('VE_ENABLED', '1')))
-    ve_dim = int(os.environ.get('VE_DIM', 128))
-    ve_layers = os.environ.get('VE_LAYERS', '9,10')
-    qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0))
-
-    # Optimizer (Modification 3: weight decay 0.090)
-    min_lr = float(os.environ.get('MIN_LR', 0.0))
-    embed_lr = float(os.environ.get('EMBED_LR', 0.6))
-    head_lr = float(os.environ.get('HEAD_LR', 0.008))
-    tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03))
-    tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005))
-    matrix_lr = float(os.environ.get('MATRIX_LR', 0.02))
-    scalar_lr = float(os.environ.get('SCALAR_LR', 0.02))
-    muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99))
-    muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5))
-    muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92))
-    muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500))
-    beta1 = float(os.environ.get('BETA1', 0.9))
-    beta2 = float(os.environ.get('BETA2', 0.95))
-    adam_eps = float(os.environ.get('ADAM_EPS', 1e-8))
-    grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3))
-    eval_stride = int(os.environ.get('EVAL_STRIDE', 64))
-    muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95))
-    adam_wd = float(os.environ.get('ADAM_WD', 0.02))
-    muon_wd = float(os.environ.get('MUON_WD', 0.090))
-    embed_wd = float(os.environ.get('EMBED_WD', 0.090))
-    ema_decay = float(os.environ.get('EMA_DECAY', 0.997))
-
-    # Depth Recurrence (Modification 2)
-    recur_layers = os.environ.get("RECUR_LAYERS", "4,5")
-    recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000))
-
-    # TTT (Modification 4)
-    ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0")))
-    ttt_lr = float(os.environ.get("TTT_LR", 0.002))
-    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3))
-    ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768))
-    ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0))
-    ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
-    ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
-    ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0))
-
-    # Compression
-    compressor = os.environ.get('COMPRESSOR', 'brotli') #(lzma or brotli)
-    gptq_enabled = bool(int(os.environ.get('GPTQ_ENABLED', '1')))
-    gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64))
-    gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 10.0))
-
-    # Distributed setup
-    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
-    rank = int(os.environ.get("RANK", "0"))
-    world_size = int(os.environ.get("WORLD_SIZE", "1"))
-    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
-    is_main_process = rank == 0
-    grad_accum_steps = 8 // world_size
-
-    # Data paths
-    datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}')
-    train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin')
-    val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin')
-    tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model')
-
-    # Experiment files
-    logfile = f"logs/{run_id}.txt"
-    model_path = "final_model.pt"
-    quantized_model_path = "final_model.int6.ptz"
-
-# ----------------------------------------
-# Global Logging Function
-# ----------------------------------------
-
-_logger_hparams = None
-
-
-def set_logging_hparams(h: Hyperparameters) -> None:
-    global _logger_hparams
-    _logger_hparams = h
-
-
-def log(msg, console: bool = True) -> None:
-    if _logger_hparams is None:
-        print(msg)
-    if _logger_hparams.is_main_process:
-        if console:
-            print(msg)
-        if _logger_hparams.logfile is not None:
-            with open(_logger_hparams.logfile, "a", encoding="utf-8") as f:
-                print(msg, file=f)
-
-# ----------------------------------------
-# Data Loading
-# ----------------------------------------
-
-class ValidationData:
-    def __init__(self, h: Hyperparameters, device: torch.device):
-        if not h.tokenizer_path.endswith(".model"):
-            raise ValueError(f"Script only setup for SentencePiece .model file: {h.tokenizer_path}")
-        self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path)
-        if int(self.sp.vocab_size()) != h.vocab_size:
-            raise ValueError(
-                f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}"
-            )
-
-        self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len)
-        self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = (
-            build_sentencepiece_luts(self.sp, h.vocab_size, device))
-
-
-def build_sentencepiece_luts(
-    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
-) -> tuple[Tensor, Tensor, Tensor]:
-    sp_vocab_size = int(sp.vocab_size())
-    # The BPB calculation assumes "▁" is its own token so that leading-space bytes
-    # are counted correctly. See https://github.com/openai/parameter-golf/issues/897
-    assert sp.piece_to_id("\u2581") != sp.unk_id(), \
-        "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting"
-    table_size = max(sp_vocab_size, vocab_size)
-    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
-    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
-    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
-    for token_id in range(sp_vocab_size):
-        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
-            continue
-        is_boundary_token_np[token_id] = False
-        if sp.is_byte(token_id):
-            base_bytes_np[token_id] = 1
-            continue
-        piece = sp.id_to_piece(token_id)
-        if piece.startswith("\u2581"):
-            has_leading_space_np[token_id] = True
-            piece = piece[1:]
-        base_bytes_np[token_id] = len(piece.encode("utf-8"))
-    return (
-        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
-        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
-        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
-    )
-
-
-def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
-    files = [Path(p) for p in sorted(glob.glob(pattern))]
-    if not files:
-        raise FileNotFoundError(f"No files found for pattern: {pattern}")
-    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
-    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
-    usable = ((tokens.numel() - 1) // seq_len) * seq_len
-    if usable <= 0:
-        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
-    return tokens[: usable + 1]
-
-
-def load_data_shard(file: Path) -> Tensor:
-    header_bytes = 256 * np.dtype(" int:
-    key = str(file)
-    cached = _SHARD_NTOKENS_CACHE.get(key)
-    if cached is not None:
-        return cached
-    header = np.fromfile(file, dtype=" np.memmap:
-    key = str(file)
-    mm = _MMAP_CACHE.get(key)
-    if mm is not None:
-        return mm
-    n = _read_num_tokens(file)
-    mm = np.memmap(file, mode="r", dtype=" int:
-        if n <= 1:
-            return 1
-        while True:
-            s = int(self._rng.integers(1, n))
-            if math.gcd(s, n) == 1:
-                return s
-
-    def _reset_cursor(self, si: int, seq_len: int) -> None:
-        nt = int(self._num_tokens[si])
-        max_phase = min(seq_len - 1, max(0, nt - seq_len - 1))
-        phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0
-        bc = (nt - 1 - phase) // seq_len
-        self._cursor_phase[si] = phase
-        self._cursor_block_count[si] = bc
-        self._cursor_next[si] = 0
-        self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0
-        self._cursor_stride[si] = self._pick_coprime_stride(bc)
-        self._cursor_init[si] = True
-
-    def _ensure_cursor(self, si: int, seq_len: int) -> None:
-        if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]:
-            self._reset_cursor(si, seq_len)
-
-    def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None:
-        rem = count
-        while rem > 0:
-            self._ensure_cursor(si, seq_len)
-            bc = int(self._cursor_block_count[si])
-            ni = int(self._cursor_next[si])
-            take = min(rem, bc - ni)
-            phase = int(self._cursor_phase[si])
-            start = int(self._cursor_start[si])
-            stride = int(self._cursor_stride[si])
-            for j in range(take):
-                bi = (start + (ni + j) * stride) % bc
-                out.append((si, phase + bi * seq_len))
-            self._cursor_next[si] = ni + take
-            rem -= take
-
-    def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None:
-        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
-        num_seqs = local_tokens // seq_len
-        global_num_seqs = num_seqs * self.world_size
-        self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs)
-        bbc = (self._num_tokens - 1) // seq_len
-        eligible = bbc > 0
-        self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64)
-        self._base_block_counts = bbc[self._eligible_shards].astype(np.int64)
-
-    def _sample_global_windows(self) -> list[tuple[int, int]]:
-        assert self._cfg is not None and self._eligible_shards is not None
-        _, seq_len, _, gns = self._cfg
-        ec = int(self._eligible_shards.size)
-        progress = min(self._batches_built / 1800.0, 1.0)
-        remaining = np.empty(ec, dtype=np.float64)
-        for i, si in enumerate(self._eligible_shards.tolist()):
-            if self._cursor_init[si]:
-                r = int(self._cursor_block_count[si]) - int(self._cursor_next[si])
-                remaining[i] = float(max(r, 1))
-            else:
-                remaining[i] = float(self._base_block_counts[i])
-        alpha = 0.90 - 0.40 * progress
-        weights = np.power(remaining, alpha)
-        ws = float(weights.sum())
-        if not np.isfinite(ws) or ws <= 0.0:
-            weights = np.ones(ec, dtype=np.float64)
-            ws = float(weights.sum())
-        probs = weights / ws
-        low = min(max(8, self.world_size), ec, gns)
-        high = min(max(32, self.world_size * 8), ec, gns)
-        mix = max(1, min(int(round(low + progress * (high - low))), ec, gns))
-        cp = self._rng.choice(ec, size=mix, replace=False, p=probs)
-        cs = self._eligible_shards[cp]
-        cpr = probs[cp].copy()
-        cpr /= cpr.sum()
-        counts = np.ones(mix, dtype=np.int64)
-        extra = gns - mix
-        if extra > 0:
-            counts += self._rng.multinomial(extra, cpr).astype(np.int64)
-        perm = self._rng.permutation(mix)
-        cs, counts = cs[perm], counts[perm]
-        buckets: list[list[tuple[int, int]]] = []
-        for si, cnt in zip(cs.tolist(), counts.tolist()):
-            b: list[tuple[int, int]] = []
-            self._take_from_shard(int(si), seq_len, int(cnt), b)
-            if b:
-                if len(b) > 1:
-                    bp = self._rng.permutation(len(b))
-                    b = [b[int(k)] for k in bp.tolist()]
-                buckets.append(b)
-        windows: list[tuple[int, int]] = []
-        active = [i for i, bk in enumerate(buckets) if bk]
-        while active:
-            order = self._rng.permutation(len(active))
-            new_active: list[int] = []
-            for oi in order.tolist():
-                bi = active[oi]
-                if buckets[bi]:
-                    windows.append(buckets[bi].pop())
-                if buckets[bi]:
-                    new_active.append(bi)
-            active = new_active
-        return windows
-
-    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
-        if self._cfg is None:
-            self._init_pipeline(global_tokens, seq_len, grad_accum_steps)
-        _, _, num_seqs, _ = self._cfg
-        gw = self._sample_global_windows()
-        local_w = gw[self.rank::self.world_size]
-        x = torch.empty((num_seqs, seq_len), dtype=torch.int64)
-        y = torch.empty((num_seqs, seq_len), dtype=torch.int64)
-        for slot, (si, pos) in enumerate(local_w):
-            mm = _get_shard_memmap(self.files[si])
-            window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64))
-            x[slot] = window[:-1]
-            y[slot] = window[1:]
-        self._batches_built += 1
-        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
-
-# ----------------------------------------
-# Model Architecture
-# ----------------------------------------
-
-class RMSNorm(nn.Module):
-    def __init__(self, eps: float | None = None):
-        super().__init__()
-        self.eps = eps
-
-    def forward(self, x: Tensor) -> Tensor:
-        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
-
-
-class CastedLinear(nn.Linear):
-    def forward(self, x: Tensor) -> Tensor:
-        w = self.weight.to(x.dtype)
-        bias = self.bias.to(x.dtype) if self.bias is not None else None
-        return F.linear(x, w, bias)
-
-
-class Rotary(nn.Module):
-    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
-        super().__init__()
-        self.dim = dim
-        self.base = base
-        self.train_seq_len = train_seq_len
-        self.rope_dims = rope_dims if rope_dims > 0 else dim
-        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._cos_cached: Tensor | None = None
-        self._sin_cached: Tensor | None = None
-
-    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
-        if (
-            self._cos_cached is None
-            or self._sin_cached is None
-            or self._seq_len_cached != seq_len
-            or self._cos_cached.device != device
-        ):
-            rd = self.rope_dims
-            if seq_len > self.train_seq_len:
-                scale = seq_len / self.train_seq_len
-                new_base = self.base * (scale ** (rd / (rd - 2)))
-                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
-            else:
-                inv_freq = self.inv_freq.to(device)
-            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
-            freqs = torch.outer(t, inv_freq)
-            self._cos_cached = freqs.cos()[None, :, None, :]
-            self._sin_cached = freqs.sin()[None, :, None, :]
-            self._seq_len_cached = seq_len
-        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
-
-
-def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
-    if rope_dims > 0 and rope_dims < x.size(-1):
-        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
-        half = rope_dims // 2
-        x1, x2 = x_rope[..., :half], x_rope[..., half:]
-        x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
-        return torch.cat((x_rope, x_pass), dim=-1)
-    half = x.size(-1) // 2
-    x1, x2 = x[..., :half], x[..., half:]
-    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
-
-
-class CausalSelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int, num_kv_heads: int,
-                 rope_base: float, qk_gain_init: float, train_seq_len: int):
-        super().__init__()
-        if dim % num_heads != 0:
-            raise ValueError("model_dim must be divisible by num_heads")
-        if num_heads % num_kv_heads != 0:
-            raise ValueError("num_heads must be divisible by num_kv_heads")
-        self.num_heads = num_heads
-        self.num_kv_heads = num_kv_heads
-        self.head_dim = dim // num_heads
-        if self.head_dim % 2 != 0:
-            raise ValueError("head_dim must be even for RoPE")
-        kv_dim = self.num_kv_heads * self.head_dim
-        self.c_q = CastedLinear(dim, dim, bias=False)
-        self.c_k = CastedLinear(dim, kv_dim, bias=False)
-        self.c_v = CastedLinear(dim, kv_dim, bias=False)
-        self.proj = CastedLinear(dim, dim, bias=False)
-        self.proj._zero_init = True
-        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
-        self.rope_dims = 0
-        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len)
-        self.use_xsa = False
-
-    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
-        B, T, H, D = y.shape
-        Hkv = v.size(-2)
-        group = H // Hkv
-        y_g = y.reshape(B, T, Hkv, group, D)
-        vn = F.normalize(v, dim=-1).unsqueeze(-2)
-        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
-        return (y_g - proj).reshape(B, T, H, D)
-
-    def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor:
-        bsz, seqlen, dim = x.shape
-        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
-        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
-        v = self.c_v(x)
-        if v_embed is not None:
-            v = v + v_embed
-        v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
-        q = F.rms_norm(q, (q.size(-1),))
-        k = F.rms_norm(k, (k.size(-1),))
-        cos, sin = self.rotary(seqlen, x.device, q.dtype)
-        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
-        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
-        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
-        y = flash_attn_3_func(q, k, v, causal=True)
-        if self.use_xsa:
-            y = self._xsa_efficient(y, v)
-        y = y.reshape(bsz, seqlen, dim)
-        return self.proj(y)
-
-
-class ValueEmbedding(nn.Module):
-    def __init__(self, vocab_size: int, ve_dim: int, model_dim: int):
-        super().__init__()
-        self.embed = nn.Embedding(vocab_size, ve_dim)
-        nn.init.normal_(self.embed.weight, std=0.01)
-        self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None
-        if self.proj is not None:
-            nn.init.zeros_(self.proj.weight)
-        self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))
-
-    def forward(self, token_ids: Tensor) -> Tensor:
-        h = self.embed(token_ids)
-        if self.proj is not None:
-            h = self.proj(h)
-        return h * self.scale.to(dtype=h.dtype)
-
-
-class MLP(nn.Module):
-    def __init__(self, dim: int, mlp_mult: int):
-        super().__init__()
-        hidden = int(mlp_mult * dim)
-        self.fc = CastedLinear(dim, hidden, bias=False)
-        self.proj = CastedLinear(hidden, dim, bias=False)
-        self.proj._zero_init = True
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square())
-
-
-class Block(nn.Module):
-    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int,
-                 rope_base: float, qk_gain_init: float, train_seq_len: int,
-                 layer_idx: int = 0, ln_scale: bool = False):
-        super().__init__()
-        self.attn_norm = RMSNorm()
-        self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len)
-        self.mlp = MLP(dim, mlp_mult)
-        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
-        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
-        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
-        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
-
-    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor:
-        mix = self.resid_mix.to(dtype=x.dtype)
-        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
-        attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed)
-        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
-        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor)
-        return x_out
-
-
-class GPT(nn.Module):
-    def __init__(self, h: Hyperparameters):
-        super().__init__()
-        self._ve_target_dim = h.num_kv_heads * (h.model_dim // h.num_heads)
-        if h.logit_softcap <= 0.0:
-            raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}")
-        self.tie_embeddings = h.tie_embeddings
-        self.tied_embed_init_std = h.tied_embed_init_std
-        self.logit_softcap = h.logit_softcap
-        self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim)
-        if h.embedding_dim != h.model_dim:
-            self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False)
-            self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False)
-        else:
-            self.embed_proj = None
-            self.head_proj = None
-        self.num_encoder_layers = h.num_layers // 2
-        self.num_decoder_layers = h.num_layers - self.num_encoder_layers
-        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
-        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32))
-        self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None
-        self.blocks = nn.ModuleList([
-            Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base,
-                  h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale)
-            for i in range(h.num_layers)
-        ])
-        if h.rope_dims > 0:
-            head_dim = h.model_dim // h.num_heads
-            for block in self.blocks:
-                block.attn.rope_dims = h.rope_dims
-                block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims)
-        self.ve_layer_indices = [int(x) for x in h.ve_layers.split(",") if x.strip()] if h.ve_enabled else []
-        kv_dim = self._ve_target_dim
-        if self.ve_layer_indices:
-            self.ve_shared = ValueEmbedding(h.vocab_size, h.ve_dim, kv_dim)
-            self.ve_layer_scales = nn.ParameterList(
-                [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]
-            )
-        else:
-            self.ve_shared = None
-            self.ve_layer_scales = nn.ParameterList()
-        self.value_embeds = nn.ModuleList()
-        self.final_norm = RMSNorm()
-        self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False)
-        if self.lm_head is not None:
-            self.lm_head._zero_init = True
-        if h.xsa_last_n > 0:
-            for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers):
-                self.blocks[i].attn.use_xsa = True
-
-        # Modification 2: Depth Recurrence
-        self.recur_layers = [int(x) for x in h.recur_layers.split(",") if x.strip()]
-        self._recurrence_active = False
-
-        self._init_weights()
-
-    def set_recurrence_active(self, active: bool) -> None:
-        self._recurrence_active = active
-
-    def _get_virtual_layers(self) -> list[int]:
-        """Return virtual->physical block mapping.
-        When recurrence is active, the recur_layers are repeated once,
-        e.g. with num_layers=11 and recur_layers=[4,5]:
-            [0,1,2,3, 4,5, 4,5, 6,7,8,9,10]
-        When inactive: [0,1,2,...,num_layers-1]
-        """
-        n = len(self.blocks)
-        if not self._recurrence_active or not self.recur_layers:
-            return list(range(n))
-        virtual = []
-        inserted = False
-        for i in range(n):
-            virtual.append(i)
-            if not inserted and i == self.recur_layers[-1]:
-                # repeat the recur_layers
-                for rl in self.recur_layers:
-                    virtual.append(rl)
-                inserted = True
-        return virtual
-
-    def _init_weights(self) -> None:
-        if self.tie_embeddings:
-            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
-        for name, module in self.named_modules():
-            if isinstance(module, nn.Linear):
-                if getattr(module, "_zero_init", False):
-                    nn.init.zeros_(module.weight)
-                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
-                    nn.init.orthogonal_(module.weight, gain=1.0)
-
-    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None:
-        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
-            return None
-        if ve_cache is not None and 've' not in ve_cache:
-            ve_cache['ve'] = self.ve_shared(input_ids)
-        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
-        ve_idx = self.ve_layer_indices.index(layer_idx)
-        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)
-
-    def forward_logits(self, input_ids: Tensor) -> Tensor:
-        x = self.tok_emb(input_ids)
-        x = F.rms_norm(x, (x.size(-1),))
-        if self.embed_proj is not None:
-            x = self.embed_proj(x)
-        x0 = x
-
-        virtual_layers = self._get_virtual_layers()
-        num_virtual = len(virtual_layers)
-        num_enc = num_virtual // 2
-        num_dec = num_virtual - num_enc
-
-        skips: list[Tensor] = []
-        ve_cache: dict = {}
-
-        # Encoder phase
-        for vi in range(num_enc):
-            phys_idx = virtual_layers[vi]
-            ve = self._get_ve(phys_idx, input_ids, ve_cache)
-            x = self.blocks[phys_idx](x, x0, v_embed=ve)
-            skips.append(x)
-
-        # Decoder phase with U-Net skip connections
-        for vi in range(num_dec):
-            phys_idx = virtual_layers[num_enc + vi]
-            if skips and vi < self.num_skip_weights:
-                scaled_skip = self.skip_weights[vi].to(dtype=x.dtype)[None, None, :] * skips.pop()
-                if self.skip_gates is not None:
-                    g = torch.sigmoid(self.skip_gates[vi].to(dtype=x.dtype))[None, None, :]
-                    x = torch.lerp(scaled_skip, x, g)
-                else:
-                    x = x + scaled_skip
-            ve = self._get_ve(phys_idx, input_ids, ve_cache)
-            x = self.blocks[phys_idx](x, x0, v_embed=ve)
-
-        x = self.final_norm(x)
-        if self.head_proj is not None:
-            x = self.head_proj(x)
-        if self.tie_embeddings:
-            logits_proj = F.linear(x, self.tok_emb.weight)
-        else:
-            logits_proj = self.lm_head(x)
-        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
-
-    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
target_ids: Tensor) -> Tensor: - logits = self.forward_logits(input_ids) - return F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") - - -def classify_param(name: str) -> str: - if "tok_emb" in name or "lm_head" in name: - return "embed" - if ".mlp." in name: - return "mlp" - if ".attn." in name or (".proj." in name and ".mlp." not in name): - return "attn" - return "other" - -# ---------------------------------------- -# Optimization -# ---------------------------------------- - -@torch.compile -def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= X.norm() + eps - transposed = G.size(0) > G.size(1) - if transposed: - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - return X.T if transposed else X - - -class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, - nesterov: bool = True, weight_decay: float = 0.0): - super().__init__( - params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, - nesterov=nesterov, weight_decay=weight_decay), - ) - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - distributed = dist.is_available() and dist.is_initialized() - world_size = dist.get_world_size() if distributed else 1 - rank = dist.get_rank() if distributed else 0 - for group in self.param_groups: - params = group["params"] - if not params: - continue - lr = group["lr"] - momentum = group["momentum"] - backend_steps = group["backend_steps"] - nesterov = group["nesterov"] - total_params = sum(int(p.numel()) for p in params) - updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) - curr = 0 - for i, p in enumerate(params): - if i % world_size == rank and p.grad is not None: - g = p.grad - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if nesterov: - g = g.add(buf, alpha=momentum) - # Modification 1: MuonEq-R row normalization before NS5 - update = g - row_norms = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-7) - update = update / row_norms.to(update.dtype) - g = zeropower_via_newtonschulz5(update, steps=backend_steps) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - updates_flat[curr : curr + p.numel()] = g.reshape(-1) - curr += p.numel() - if distributed: - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - wd = group.get("weight_decay", 0.0) - curr = 0 - for p in params: - if wd > 0.0: - p.data.mul_(1.0 - lr * wd) - g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) - p.add_(g, alpha=-lr) - curr += p.numel() - return loss - - -class Optimizers(): - def __init__(self, h: Hyperparameters, base_model: GPT): - block_named_params = list(base_model.blocks.named_parameters()) - matrix_params = [ - p - for name, p in block_named_params - if p.ndim == 2 and not any(pattern in name for pattern in - CONTROL_TENSOR_NAME_PATTERNS) - ] - scalar_params = [ - p - for name, p in block_named_params - if p.ndim < 2 or any(pattern in name for pattern in - CONTROL_TENSOR_NAME_PATTERNS) - ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) - if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: - 
scalar_params.append(base_model.skip_gates) - - token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr - tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] - if base_model.ve_shared is not None: - tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) - if base_model.ve_shared.proj is not None: - matrix_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales: - scalar_params.append(s) - - self.optimizer_tok = torch.optim.AdamW( - tok_params, - betas=(h.beta1, h.beta2), - eps=h.adam_eps, - weight_decay=h.embed_wd, - fused=True, - ) - self.optimizer_muon = Muon( - matrix_params, - lr=h.matrix_lr, - momentum=h.muon_momentum, - backend_steps=h.muon_backend_steps, - weight_decay=h.muon_wd, - ) - for group in self.optimizer_muon.param_groups: - group["base_lr"] = h.matrix_lr - self.optimizer_scalar = torch.optim.AdamW( - [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], - betas=(h.beta1, h.beta2), - eps=h.adam_eps, - weight_decay=h.adam_wd, - fused=True, - ) - self.optimizers: list[torch.optim.Optimizer] = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] - if base_model.lm_head is not None: - self.optimizer_head = torch.optim.Adam( - [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], - betas=(h.beta1, h.beta2), - eps=h.adam_eps, - fused=True, - ) - self.optimizers.insert(1, self.optimizer_head) - else: - self.optimizer_head = None - - def __iter__(self): - return iter(self.optimizers) - - def zero_grad_all(self) -> None: - for opt in self.optimizers: - opt.zero_grad(set_to_none=True) - - def step(self): - for opt in self.optimizers: - opt.step() - self.zero_grad_all() - -# ---------------------------------------- -# Quantization -# ---------------------------------------- - -CONTROL_TENSOR_NAME_PATTERNS = tuple( - pattern - for pattern in os.environ.get( - "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale", - ).split(",") - if pattern -) -INT8_PER_ROW_SCALE_DTYPE = torch.float16 -INT8_CLIP_PERCENTILE = 99.99984 -INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 - - -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() - return q, scale - - -def restore_fp32_params(model: nn.Module) -> None: - """After .bfloat16(), restore CastedLinear weights and control params to FP32.""" - for module in model.modules(): - if isinstance(module, CastedLinear): - module.float() - for name, param in model.named_parameters(): - if (param.ndim < 2 or any(pattern in name for 
pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: - param.data = param.data.float() - - -def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: - t32 = t.float() - if t32.ndim == 2: - best_q, best_s, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(t32.abs(), pct, dim=1) - else: - row_clip = t32.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) - q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) - recon = q.float() * s.float()[:, None] - err = (t32 - recon).pow(2).mean().item() - if err < best_err: - best_q, best_s, best_err = q, s, err - return best_q, best_s - amax = t32.abs().max().item() - scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) - q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) - return q, scale - - -def collect_hessians( - model: nn.Module, - train_loader: DistributedTokenLoader, - h: Hyperparameters, - device: torch.device, - n_calibration_batches: int = 64, -) -> dict[str, Tensor]: - """Run calibration batches and collect H = X^T X for each CastedLinear layer.""" - hessians: dict[str, Tensor] = {} - hooks = [] - - def make_hook(name: str): - def hook_fn(module, inp, out): - x = inp[0].detach().float() - if x.ndim == 3: - x = x.reshape(-1, x.shape[-1]) - if name not in hessians: - hessians[name] = torch.zeros( - x.shape[1], x.shape[1], dtype=torch.float32, device=device - ) - hessians[name].addmm_(x.T, x) - return hook_fn - - for name, module in model.named_modules(): - if isinstance(module, CastedLinear) and module.weight.numel() > 65536: - cat = classify_param(name + ".weight") - if cat in ("mlp", "attn"): - hooks.append(module.register_forward_hook(make_hook(name + ".weight"))) - - model.eval() - with torch.no_grad(): - for _i in range(n_calibration_batches): - x, y = train_loader.next_batch( - h.train_batch_tokens, - h.train_seq_len, h.grad_accum_steps, - ) - model.forward_logits(x) - - for hk in hooks: - hk.remove() - - for name in hessians: - hessians[name] = hessians[name].cpu() / n_calibration_batches - - return hessians - - -def gptq_quantize_weight( - w: Tensor, - H: Tensor, - clip_range: int = 31, - block_size: int = 128, -) -> tuple[Tensor, Tensor]: - """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023).""" - W_orig = w.float().clone() - rows, cols = W_orig.shape - H = H.float().clone() - - # Zero out dead columns and add damping - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - damp = 0.01 * H.diag().mean() - H.diagonal().add_(damp) - - # Column reordering by descending Hessian diagonal (actorder) - perm = torch.argsort(H.diag(), descending=True) - invperm = torch.argsort(perm) - W_perm = W_orig[:, perm].clone() - W_perm[:, dead[perm]] = 0 - H = H[perm][:, perm] - - # Upper Cholesky of the inverse - try: - Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) - Hinv = torch.linalg.cholesky(Hinv, upper=True) - except torch.linalg.LinAlgError: - return quantize_int6_per_row(W_orig, clip_range) - - # Search over scale candidates, running full GPTQ for each - best_q, best_scale, best_err = None, None, float('inf') - for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: - if pct < 1.0: - row_clip = torch.quantile(W_orig.abs(), pct, dim=1) - else: - row_clip = W_orig.abs().amax(dim=1) - s = (row_clip / clip_range).clamp_min(1.0 / 
clip_range).to(torch.float16) - sf = s.float() - - Q = torch.zeros(rows, cols, dtype=torch.int8) - W_work = W_perm.clone() - - for i1 in range(0, cols, block_size): - i2 = min(i1 + block_size, cols) - W_block = W_work[:, i1:i2].clone() - Hinv_block = Hinv[i1:i2, i1:i2] - Err = torch.zeros(rows, i2 - i1) - for j in range(i2 - i1): - w_col = W_block[:, j] - d = Hinv_block[j, j] - q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) - Q[:, i1 + j] = q_col.to(torch.int8) - err = (w_col - q_col.float() * sf) / d - Err[:, j] = err - W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) - if i2 < cols: - W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] - - recon = Q.float() * sf[:, None] - mse = (W_perm - recon).pow(2).mean().item() - if mse < best_err: - best_q, best_scale, best_err = Q, s, mse - - return best_q[:, invperm], best_scale - - -def gptq_mixed_quantize_int6( - state_dict: dict[str, Tensor], - int6_cats: set[str], - hessians: dict[str, Tensor], -) -> tuple[dict[str, Tensor], dict[str, object]]: - """Mixed quantization using full GPTQ for layers with Hessians, fallback to clip-search.""" - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - gptq_count = 0 - fallback_count = 0 - - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = classify_param(name) - - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - - if cat in int6_cats and t.ndim == 2: - if name in hessians: - q, s = gptq_quantize_weight(t, hessians[name]) - gptq_count += 1 - meta[name] = {"type": "int6", "method": "gptq"} - else: - q, s = quantize_int6_per_row(t) - fallback_count += 1 - meta[name] = {"type": "int6", "method": "clip_search"} - result[name + ".q"] = q - result[name + ".scale"] = s - elif cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - - log(f"GPTQ quantization: {gptq_count} layers with full GPTQ, {fallback_count} fallback to clip-search") - return result, meta - - -def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): - result: dict[str, Tensor] = {} - meta: dict[str, object] = {} - for name, tensor in state_dict.items(): - t = tensor.detach().cpu().contiguous() - cat = classify_param(name) - if not t.is_floating_point() or t.numel() <= 65536: - result[name] = t.to(torch.float16) if t.is_floating_point() else t - meta[name] = "passthrough" - continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): - result[name] = t.float() - meta[name] = "passthrough_ctrl" - continue - if cat in int6_cats and t.ndim >= 1: - q, s = quantize_int6_per_row(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int6"} - else: - q, s = quantize_float_tensor(t) - result[name + ".q"] = q - result[name + ".scale"] = s - meta[name] = {"type": "int8"} - return result, meta - - -def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], - template_sd: dict[str, Tensor]) -> dict[str, Tensor]: - out: dict[str, Tensor] = {} - for name, orig in template_sd.items(): - info = meta.get(name) - if info is None: - 
continue - orig_dtype = orig.dtype - if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): - t = result[name] - if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): - t = t.to(orig_dtype) - out[name] = t - continue - q, s = result[name + ".q"], result[name + ".scale"] - if s.ndim > 0: - out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) - else: - out[name] = (q.float() * float(s.item())).to(orig_dtype) - return out - - -_BSHF_MAGIC = b"BSHF" - - -def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: - """Transpose byte stream by stride position for better compression.""" - if stride <= 1 or len(data) < stride: - return data - src = np.frombuffer(data, dtype=np.uint8) - n = len(src) - out = np.empty(n, dtype=np.uint8) - dest_off = 0 - for pos in range(stride): - chunk = src[pos::stride] - out[dest_off:dest_off + len(chunk)] = chunk - dest_off += len(chunk) - return _BSHF_MAGIC + bytes([stride]) + out.tobytes() - - -def _byte_unshuffle(data: bytes) -> bytes: - """Inverse of _byte_shuffle. Auto-detects BSHF magic header.""" - if len(data) < 5 or data[:4] != _BSHF_MAGIC: - return data - stride = data[4] - if stride < 2: - return data[5:] - payload = np.frombuffer(data, dtype=np.uint8, offset=5) - n = len(payload) - out = np.empty(n, dtype=np.uint8) - src_off = 0 - for pos in range(stride): - chunk_len = n // stride + (1 if pos < n % stride else 0) - out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] - src_off += chunk_len - return out.tobytes() - - -def _compress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: - if byte_shuffle: - data = _byte_shuffle(data) - if compressor == "lzma": - return lzma.compress(data, preset=6) - elif compressor == "brotli": - import brotli as _brotli - return _brotli.compress(data, quality=11) - raise ValueError(f"Unknown compressor: {compressor!r}") - - -def _decompress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: - if compressor == "lzma": - raw = lzma.decompress(data) - elif compressor == "brotli": - import brotli as _brotli - raw = _brotli.decompress(data) - else: - raise ValueError(f"Unknown compressor: {compressor!r}") - if byte_shuffle: - raw = _byte_unshuffle(raw) - return raw - - -def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> int: - model_bytes = None - code_bytes = len(code.encode("utf-8")) - if h.is_main_process: - torch.save(base_model.state_dict(), h.model_path) - model_bytes = os.path.getsize(h.model_path) - log(f"Serialized model: {model_bytes} bytes") - log(f"Code size: {code_bytes} bytes") - - sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} - if h.gptq_enabled: - log("GPTQ:collecting Hessians from calibration data...") - t0 = time.perf_counter() - calib_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, - torch.device("cuda", h.local_rank)) - hessians = collect_hessians( - base_model, calib_loader, h, - torch.device("cuda", h.local_rank), - n_calibration_batches=h.gptq_calibration_batches, - ) - log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s") - quant_result, quant_meta = gptq_mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, hessians) - else: - quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) - - # Fast selective +-1 pruning to fit under target size - target_bytes = 16_000_000 - quant_buf_check = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf_check) - 
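# Size-budget check (descriptive comments on the pruning pass that follows):
# the scored artifact is compress(byte_shuffle(torch.save payload)) plus the
# UTF-8 size of the training script itself, so this trial buffer is
# compressed once to measure the unpruned total. If it exceeds the 16 MB
# target, the loop below zeroes the int6 cells equal to +/-1 that have the
# smallest squared dequantized magnitude (scale^2), pruning 8x the byte
# excess as a safety margin because the exact brotli savings from zeroing a
# given cell cannot be known in advance.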
check_blob = _compress(quant_buf_check.getvalue(), h.compressor) - unpruned_sz = len(check_blob) + code_bytes - log(f"selective_prune: unpruned={unpruned_sz/1e6:.2f}MB target={target_bytes/1e6:.1f}MB") - if unpruned_sz > target_bytes: - excess = unpruned_sz - target_bytes - safety_margin = int(excess * 8) # prune 8x the excess for safety - ones_info = [] - for name, info in quant_meta.items(): - if not (isinstance(info, dict) and info.get("type") == "int6"): - continue - qk, sk = name + ".q", name + ".scale" - if qk not in quant_result or sk not in quant_result: - continue - q, s = quant_result[qk], quant_result[sk] - if s.ndim > 0: - ones_mask = (q.abs() == 1) - if ones_mask.any(): - row_idx = torch.arange(q.shape[0]).unsqueeze(1).expand_as(q)[ones_mask] - flat_idx = torch.arange(q.numel()).reshape(q.shape)[ones_mask] - errors = s.float()[row_idx].pow(2) - for fi, err in zip(flat_idx.tolist(), errors.tolist()): - ones_info.append((qk, fi, err)) - ones_info.sort(key=lambda x: x[2]) - n_prune = min(safety_margin, len(ones_info)) - log(f"selective_prune: pruning {n_prune}/{len(ones_info)} lowest-error ±1 values (excess={excess}B)") - for i in range(n_prune): - quant_result[ones_info[i][0]].view(-1)[ones_info[i][1]] = 0 - else: - log("selective_prune: already fits, no pruning needed") - - quant_buf = io.BytesIO() - torch.save({"w": quant_result, "m": quant_meta}, quant_buf) - quant_raw = quant_buf.getvalue() - quant_blob = _compress(quant_raw, h.compressor) - quant_file_bytes = len(quant_blob) - bytes_total = quant_file_bytes + code_bytes - if h.is_main_process: - with open(h.quantized_model_path, "wb") as f: - f.write(quant_blob) - log(f"Serialized model int6+{h.compressor}: {quant_file_bytes} bytes") - log(f"Total submission size int6+{h.compressor}: {bytes_total} bytes") - - -def deserialize(h: Hyperparameters, device: torch.device) -> GPT: - eval_model = GPT(h).to(device).bfloat16() - restore_fp32_params(eval_model) - - sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} - - with open(h.quantized_model_path, "rb") as f: - quant_blob_disk = f.read() - quant_state = torch.load( - io.BytesIO(_decompress(quant_blob_disk, h.compressor)), - map_location="cpu", - ) - deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) - eval_model.load_state_dict(deq_state, strict=True) - - return eval_model - -# ---------------------------------------- -# Evaluation -# ---------------------------------------- - -def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]: - val_loss = (loss_sum / token_count).item() - val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) - return val_loss, val_bpb - - -def eval_val( - h: Hyperparameters, - device: torch.device, - val_data: ValidationData, - model: nn.Module -) -> tuple[float, float]: - seq_len = h.eval_seq_len - local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) - if local_batch_tokens < seq_len: - raise ValueError( - "VAL_BATCH_SIZE must provide at least one sequence per rank; " - f"got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " - f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" - ) - local_batch_seqs = local_batch_tokens // seq_len - total_seqs = (val_data.val_tokens.numel() - 1) // seq_len - seq_start = (total_seqs * h.rank) // h.world_size - seq_end = (total_seqs * (h.rank + 1)) // h.world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, 
dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - model.eval() - with torch.inference_mode(): - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * seq_len - raw_end = batch_seq_end * seq_len + 1 - local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - model.train() - return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) - - -def eval_val_sliding( - h: Hyperparameters, - device: torch.device, - val_data: ValidationData, - base_model: nn.Module, - batch_seqs: int = 32 -) -> tuple[float, float]: - """Sliding window evaluation: each token scored with maximum context.""" - base_model.eval() - logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) - - seq_len = h.eval_seq_len - context_size = seq_len - h.eval_stride - total_tokens = val_data.val_tokens.numel() - 1 - - window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) - if ws + context_size < total_tokens] - - total_windows = len(window_starts) - my_s = (total_windows * h.rank) // h.world_size - my_e = (total_windows * (h.rank + 1)) // h.world_size - my_windows = window_starts[my_s:my_e] - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - with torch.inference_mode(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - - for i, ws in enumerate(batch_ws): - we = min(ws + seq_len, total_tokens) - wlen = we - ws - wlens.append(wlen) - chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk[:-1] - y_batch[i, :wlen] = chunk[1:] - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = logits_fn(x_batch) - - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) - - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else context_size - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt = y_batch[i, s:wlen] - prev = x_batch[i, s:wlen] - tb = val_data.base_bytes_lut[tgt].to(torch.float64) - tb += 
(val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - base_model.train() - return _loss_bpb(loss_sum, token_count, byte_count) - - -# ---------------------------------------- -# TTT (Test-Time Training) - Legal Score-First -# ---------------------------------------- - -def eval_val_ttt( - h: Hyperparameters, - base_model: nn.Module, - device: torch.device, - val_data: ValidationData, - log_fn=None, -) -> tuple[float, float]: - """Legal score-first TTT: score each chunk with sliding windows, - then train on it. Every token scored BEFORE any update that could use it.""" - seq_len = h.eval_seq_len - stride = h.eval_stride - total_tokens = val_data.val_tokens.numel() - 1 - ttt_chunk = h.ttt_chunk_tokens - rank = h.rank - world_size = h.world_size - if log_fn is None: - log_fn = lambda msg: None - - window_starts = [ws for ws in range(0, total_tokens, stride) - if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] - - num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk - chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] - for ws in window_starts: - end = min(ws + seq_len, total_tokens) - wlen = end - ws - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_start = ws + s - ci = min(scored_start // ttt_chunk, num_chunks - 1) - chunk_windows[ci].append(ws) - - log_fn(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " - f"total_windows={len(window_starts)} stride={stride} " - f"ttt_lr={h.ttt_lr} ttt_epochs={h.ttt_epochs} " - f"freeze_blocks={h.ttt_freeze_blocks}") - - loss_sum = torch.zeros((), device=device, dtype=torch.float64) - token_count = torch.zeros((), device=device, dtype=torch.float64) - byte_count = torch.zeros((), device=device, dtype=torch.float64) - - frozen_block_ids = set(range(min(h.ttt_freeze_blocks, len(base_model.blocks)))) - ttt_params = [] - for name, p in base_model.named_parameters(): - freeze = False - for bi in frozen_block_ids: - if f"blocks.{bi}." 
in name: - freeze = True - break - if freeze: - p.requires_grad_(False) - else: - p.requires_grad_(True) - ttt_params.append(p) - - log_fn(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " - f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") - - optimizer = torch.optim.SGD(ttt_params, lr=h.ttt_lr, momentum=h.ttt_momentum) - batch_seqs = h.ttt_batch_seqs - t0 = time.perf_counter() - - for ci in range(num_chunks): - windows = chunk_windows[ci] - if not windows: - continue - chunk_start = ci * ttt_chunk - chunk_end = min((ci + 1) * ttt_chunk, total_tokens) - - # --- Phase 1: SCORE this chunk's windows (no_grad for TTT compat) --- - my_s = (len(windows) * rank) // world_size - my_e = (len(windows) * (rank + 1)) // world_size - my_windows = windows[my_s:my_e] - - base_model.eval() - with torch.no_grad(): - for bi in range(0, len(my_windows), batch_seqs): - batch_ws = my_windows[bi:bi + batch_seqs] - bsz = len(batch_ws) - x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens: list[int] = [] - for i, ws in enumerate(batch_ws): - end = min(ws + seq_len, total_tokens) - wlen = end - ws - wlens.append(wlen) - chunk_tok = val_data.val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) - x_batch[i, :wlen] = chunk_tok[:-1] - y_batch[i, :wlen] = chunk_tok[1:] - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), reduction="none", - ).reshape(bsz, seq_len) - for i, ws in enumerate(batch_ws): - wlen = wlens[i] - s = 0 if ws == 0 else max(wlen - stride, 0) - scored_nll = nll[i, s:wlen].to(torch.float64) - loss_sum += scored_nll.sum() - token_count += float(wlen - s) - tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] - tb = val_data.base_bytes_lut[tgt].to(torch.float64) - tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) - byte_count += tb.sum() - - # --- Phase 2: TRAIN on this chunk (already scored = legal) --- - is_last_chunk = (ci == num_chunks - 1) - if not is_last_chunk and h.ttt_epochs > 0: - base_model.train() - chunk_seqs = (chunk_end - chunk_start) // seq_len - if chunk_seqs > 0: - cos_lr = h.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) - for pg in optimizer.param_groups: - pg['lr'] = cos_lr - my_seq_s = (chunk_seqs * rank) // world_size - my_seq_e = (chunk_seqs * (rank + 1)) // world_size - my_chunk_seqs = my_seq_e - my_seq_s - for _ep in range(h.ttt_epochs): - for bs in range(0, my_chunk_seqs, batch_seqs): - be = min(bs + batch_seqs, my_chunk_seqs) - actual_bs = my_seq_s + bs - start_tok = chunk_start + actual_bs * seq_len - end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 - if end_tok > val_data.val_tokens.numel(): - continue - local = val_data.val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) - x = local[:-1].reshape(-1, seq_len) - y = local[1:].reshape(-1, seq_len) - optimizer.zero_grad(set_to_none=True) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - loss = base_model(x, y) - loss.backward() - if world_size > 1: - for p in ttt_params: - if p.grad is not None: - dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) - torch.nn.utils.clip_grad_norm_(ttt_params, h.ttt_grad_clip) - optimizer.step() - - if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): - elapsed = 
time.perf_counter() - t0 - rl = loss_sum.item() / max(token_count.item(), 1) - rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 - log_fn(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) - - val_loss = (loss_sum / token_count).item() - val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) - - for p in base_model.parameters(): - p.requires_grad_(True) - base_model.eval() - - log_fn(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " - f"elapsed={time.perf_counter() - t0:.1f}s") - return val_loss, val_bpb - - -# ---------------------------------------- -# Eval orchestration -# ---------------------------------------- - -def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]: - torch.cuda.synchronize() - t0 = time.perf_counter() - val_loss, val_bpb = fn(*args, **kwargs) - torch.cuda.synchronize() - elapsed_ms = 1000.0 * (time.perf_counter() - t0) - log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") - return val_loss, val_bpb - - -def run_evals( - h: Hyperparameters, - device: torch.device, - val_data: ValidationData, - eval_model: torch.nn.Module -): - compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) - timed_eval("final_int6_roundtrip", eval_val, h, device, val_data, compiled_model) - if h.sliding_window_enabled: - timed_eval("final_int6_sliding_window", eval_val_sliding, h, device, val_data, eval_model) - if h.ttt_enabled: - timed_eval("final_int6_ttt", eval_val_ttt, h, eval_model, device, val_data, log_fn=log) - -# ----------------------------- -# Training -# ----------------------------- - -def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData) -> None: - # Set up model - base_model = GPT(h).to(device).bfloat16() - restore_fp32_params(base_model) - compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) - if h.distributed: - model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) - else: - model = compiled_model - log(f"model_params:{sum(p.numel() for p in base_model.parameters())}") - - # Set up optimizer and load train data - optimizers = Optimizers(h, base_model) - train_loader = DistributedTokenLoader( h.train_files, h.rank, h.world_size, device) - - # Helper functions for training - max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None - if h.gptq_enabled and max_wallclock_ms is not None: - max_wallclock_ms -= h.gptq_reserve_seconds * 1000.0 - log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") - - def training_frac(step: int, elapsed_ms: float) -> float: - """Fraction of training completed (0 to 1), using step or wallclock.""" - if max_wallclock_ms is None: - return step / max(h.iterations, 1) - return elapsed_ms / max(max_wallclock_ms, 1e-9) - - def lr_mul(frac: float) -> float: - if h.warmdown_frac <= 0: - return 1.0 - if frac >= 1.0 - h.warmdown_frac: - return max((1.0 - frac) / h.warmdown_frac, h.min_lr) - return 1.0 - - def step_fn(step, lr_scale): - optimizers.zero_grad_all() - train_loss = torch.zeros((), device=device) - for micro_step in range(h.grad_accum_steps): - if h.distributed: - model.require_backward_grad_sync = 
micro_step == h.grad_accum_steps - 1 - x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - loss = model(x, y) - train_loss += loss.detach() - (loss / h.grad_accum_steps).backward() - train_loss /= h.grad_accum_steps - - frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 - muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum - for group in optimizers.optimizer_muon.param_groups: - group["momentum"] = muon_momentum - - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["base_lr"] * lr_scale - - if h.grad_clip_norm > 0: - torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) - - optimizers.step() - return train_loss - - # Model warmup - if h.warmup_steps > 0: - initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} - initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] - model.train() - for warmup_step in range(h.warmup_steps): - step_fn(warmup_step, 1.0) - if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: - log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") - base_model.load_state_dict(initial_model_state, strict=True) - for opt, state in zip(optimizers, initial_optimizer_states, strict=True): - opt.load_state_dict(state) - optimizers.zero_grad_all() - if h.distributed: - model.require_backward_grad_sync = True - train_loader = DistributedTokenLoader( - h.train_files, h.rank, h.world_size, device) - - # Training loop - ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} - ema_decay = h.ema_decay - - training_time_ms = 0.0 - stop_after_step: int | None = None - torch.cuda.synchronize() - t0 = time.perf_counter() - - step = 0 - while True: - last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) - - # Modification 2: activate recurrence at recur_start_step - if step == h.recur_start_step and not base_model._recurrence_active: - base_model.set_recurrence_active(True) - log(f"recurrence:activated at step {step}, virtual_layers={base_model._get_virtual_layers()}") - - should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) - if should_validate: - torch.cuda.synchronize() - training_time_ms += 1000.0 * (time.perf_counter() - t0) - val_loss, val_bpb = eval_val(h, device, val_data, model) - log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if stop_after_step is not None and step < h.iterations: - log( - f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " - f"step: {step}/{h.iterations}" - ) - break - - elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - frac = training_frac(step, elapsed_ms) - scale = lr_mul(frac) - train_loss = step_fn(step, scale) - - with torch.no_grad(): - for name, t in base_model.state_dict().items(): - ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) - - step += 1 - approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - - should_log_train = ( - h.train_log_every > 0 - and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) - ) - if should_log_train: - tok_per_sec = step * 
h.train_batch_tokens / (approx_training_time_ms / 1000.0) - log( - f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} " - f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" - ) - - reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms - if h.distributed and max_wallclock_ms is not None: - reached_cap_tensor = torch.tensor(int(reached_cap), device=device) - dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) - reached_cap = bool(reached_cap_tensor.item()) - if stop_after_step is None and reached_cap: - stop_after_step = step - - log( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" - ) - - # Weight averaging - log("ema:applying EMA weights") - current_state = base_model.state_dict() - avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} - base_model.load_state_dict(avg_state, strict=True) - - return base_model, compiled_model - - -def train_and_eval(h: Hyperparameters, device: torch.device) -> None: - random.seed(h.seed) - np.random.seed(h.seed) - torch.manual_seed(h.seed) - torch.cuda.manual_seed_all(h.seed) - - val_data = ValidationData(h, device) - log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") - log(f"val_tokens: {val_data.val_tokens.numel() - 1}") - - base_model, compiled_model = train_model(h, device, val_data) - timed_eval("pre-quantization post-ema", eval_val, h, device, val_data, compiled_model) - - serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) - if h.distributed: - dist.barrier() - - eval_model = deserialize(h, device) - # Activate recurrence on eval model for consistent evaluation - eval_model.set_recurrence_active(base_model._recurrence_active) - - run_evals(h, device, val_data, eval_model) - - -def main(): - # Modification 2: increase dynamo cache size for recurrence - torch._dynamo.config.cache_size_limit = 32 - - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", "0")) - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required") - if world_size <= 0: - raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - - device = torch.device("cuda", local_rank) - torch.cuda.set_device(device) - if distributed: - dist.init_process_group(backend="nccl", device_id=device) - dist.barrier() - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.set_float32_matmul_precision("high") - from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - - enable_cudnn_sdp(False) - enable_flash_sdp(True) - enable_mem_efficient_sdp(False) - enable_math_sdp(False) - torch._dynamo.config.optimize_ddp = False - - h = Hyperparameters() - set_logging_hparams(h) - if h.is_main_process: - os.makedirs("logs", exist_ok=True) - log(100 * "=", console=False) - log("Hyperparameters:", console=True) - for k, v in sorted(vars(type(h)).items()): - if not k.startswith("_"): - log(f" {k}: {v}", console=True) - log(Path(__file__).read_text(encoding="utf-8"), console=False) - log("=" * 100, console=False) - log(f"Running Python {sys.version}", console=False) - log(f"Running 
PyTorch {torch.__version__}", console=False) - log( - subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, - console=False, - ) - log("=" * 100, console=False) - - train_and_eval(h, device) - - if distributed: - dist.destroy_process_group() - - -if __name__ == "__main__": - main() - -==================================================================================================== -Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] -Running PyTorch 2.9.1+cu128 -Fri Apr 3 11:06:30 2026 -+-----------------------------------------------------------------------------------------+ -| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | -+-----------------------------------------+------------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. | -|=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | -| N/A 36C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | -| N/A 31C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | -| N/A 30C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | -| N/A 34C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | -| N/A 36C P0 121W / 700W | 1521MiB / 81559MiB | 6% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | -| N/A 32C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | -| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | -| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+------------------------+----------------------+ - -+-----------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=========================================================================================| -| No running processes found | -+-----------------------------------------------------------------------------------------+ - -==================================================================================================== train_shards: 80 val_tokens: 45508608 model_params:34401371 @@ -2059,36 +97,36 @@ warmup_step: 6/20 warmup_step: 10/20 warmup_step: 
20/20 0/20000 val_loss: 8.3152 val_bpb: 3.6137 -1/20000 train_loss: 8.3175 train_time: 0.0m tok/s: 8454987 -2/20000 train_loss: 12.3306 train_time: 0.0m tok/s: 8334503 -3/20000 train_loss: 10.8414 train_time: 0.0m tok/s: 8219626 -4/20000 train_loss: 8.9815 train_time: 0.0m tok/s: 8164354 -5/20000 train_loss: 7.7899 train_time: 0.0m tok/s: 8130397 -500/20000 train_loss: 2.9043 train_time: 0.8m tok/s: 7895416 -1000/20000 train_loss: 2.8890 train_time: 1.7m tok/s: 7879347 -1500/20000 train_loss: 2.9171 train_time: 2.5m tok/s: 7874664 -2000/20000 train_loss: 2.6567 train_time: 3.3m tok/s: 7869658 -2500/20000 train_loss: 2.7134 train_time: 4.2m tok/s: 7868619 -3000/20000 train_loss: 2.7648 train_time: 5.0m tok/s: 7867899 +1/20000 train_loss: 8.3175 train_time: 0.0m tok/s: 8427288 +2/20000 train_loss: 12.3306 train_time: 0.0m tok/s: 8355799 +3/20000 train_loss: 10.8414 train_time: 0.0m tok/s: 8254132 +4/20000 train_loss: 8.9815 train_time: 0.0m tok/s: 8207037 +5/20000 train_loss: 7.7899 train_time: 0.0m tok/s: 8175906 +500/20000 train_loss: 2.9026 train_time: 0.8m tok/s: 7928345 +1000/20000 train_loss: 2.8868 train_time: 1.7m tok/s: 7897402 +1500/20000 train_loss: 2.9194 train_time: 2.5m tok/s: 7887462 +2000/20000 train_loss: 2.6598 train_time: 3.3m tok/s: 7882463 +2500/20000 train_loss: 2.7139 train_time: 4.2m tok/s: 7880860 +3000/20000 train_loss: 2.7634 train_time: 5.0m tok/s: 7882252 recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10] -3500/20000 train_loss: 2.6864 train_time: 6.1m tok/s: 7465573 -4000/20000 train_loss: 2.6244 train_time: 7.1m tok/s: 7377189 -4000/20000 val_loss: 2.6459 val_bpb: 1.1499 -4500/20000 train_loss: 2.5756 train_time: 8.1m tok/s: 7310795 -5000/20000 train_loss: 2.5210 train_time: 9.0m tok/s: 7257843 -5418/20000 val_loss: 2.5333 val_bpb: 1.1009 -stopping_early: wallclock_cap train_time: 590023ms step: 5418/20000 +3500/20000 train_loss: 2.6898 train_time: 6.1m tok/s: 7479088 +4000/20000 train_loss: 2.6233 train_time: 7.1m tok/s: 7390618 +4000/20000 val_loss: 2.6466 val_bpb: 1.1502 +4500/20000 train_loss: 2.5741 train_time: 8.1m tok/s: 7322551 +5000/20000 train_loss: 2.5200 train_time: 9.0m tok/s: 7269901 +5427/20000 val_loss: 2.5333 val_bpb: 1.1010 +stopping_early: wallclock_cap train_time: 590048ms step: 5427/20000 peak memory allocated: 30119 MiB reserved: 30156 MiB ema:applying EMA weights -pre-quantization post-ema val_loss:2.53080694 val_bpb:1.09985954 eval_time:2011ms +pre-quantization post-ema val_loss:2.53084430 val_bpb:1.09987578 eval_time:2003ms Serialized model: 132405827 bytes -Code size: 80967 bytes +Code size: 23948 bytes GPTQ:collecting Hessians from calibration data... 
GPTQ:collected 66 Hessians in 9.7s GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search -selective_prune: unpruned=16.03MB target=16.0MB -selective_prune: pruning 269368/9382344 lowest-error ±1 values (excess=33671B) -Serialized model int6+brotli: 15877688 bytes -Total submission size int6+brotli: 15958655 bytes -final_int6_roundtrip val_loss:2.55991644 val_bpb:1.11251020 eval_time:7641ms -final_int6_sliding_window val_loss:2.51788435 val_bpb:1.09424353 eval_time:75624ms +selective_prune: unpruned=15.98MB target=16.0MB +selective_prune: already fits, no pruning needed +Serialized model int6+brotli: 15953548 bytes +Total submission size int6+brotli: 15977496 bytes +final_int6_roundtrip val_loss:2.55674680 val_bpb:1.11113270 eval_time:7580ms +final_int6_sliding_window val_loss:2.51447678 val_bpb:1.09276264 eval_time:75764ms From 02340cdb7699a4b295a6beb9abb158849560b135 Mon Sep 17 00:00:00 2001 From: Aryan Bhosale Date: Fri, 3 Apr 2026 19:46:30 +0530 Subject: [PATCH 3/4] =?UTF-8?q?Update:=20add=20parallel=20residuals=20?= =?UTF-8?q?=E2=80=94=20val=5Fbpb=201.0904=20(3-seed=20mean)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added parallel residuals from layer 7+ (separate attn/MLP lanes). 3-seed mean improves from 1.0926 to 1.0904. --- .../README.md | 43 ++-- .../submission.json | 32 +-- .../train_gpt.py | 3 +- .../train_seed314.log | 73 +++--- .../train_seed42.log | 219 +++++++++++++++--- .../train_seed999.log | 73 +++--- 6 files changed, 302 insertions(+), 141 deletions(-) diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md index 06cf5009b0..b340021aa8 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md @@ -1,28 +1,40 @@ -# Record: SP4096 + Depth Recurrence + MuonEq-R + Full GPTQ — val_bpb 1.0926 (3-seed mean) +# Record: SP4096 + Depth Recurrence + Parallel Residuals + MuonEq-R + Full GPTQ — val_bpb 1.0904 (3-seed mean) -**val_bpb = 1.0926** (3-seed mean, std 0.0009) | **~15.98 MB** | 8xH100 SXM +**val_bpb = 1.0904** (3-seed mean, std 0.0016) | **~15.98 MB** | 8xH100 SXM ## 3-Seed Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128) | Seed | Steps | **Sliding BPB** | Artifact | |------|-------|-----------------|----------| -| 42 | 5,415 | **1.0935** | 15,999,165 | -| 314 | 5,415 | **1.0917** | 15,963,773 | -| 999 | 5,420 | **1.0928** | 15,977,496 | -| **Mean** | | **1.0926** | | +| 42 | 5,279 | **1.0923** | 15,965,928 | +| 314 | 5,279 | **1.0894** | 15,997,318 | +| 999 | 5,279 | **1.0896** | 15,990,607 | +| **Mean** | | **1.0904** | | -Merged SOTA (PR #1019): **1.1147 BPB**. Delta: **-0.0221 BPB**. +Merged SOTA (PR #1019): **1.1147 BPB**. Delta: **-0.0243 BPB**. ## Changes from Merged SOTA -1. **4096-Vocab + MLP 4x + WD 0.090** — sp4096 tokenizer, wider MLP, higher WD for better quantization compression. Source: PR #1218 @clarkkev, PR #1285 @dexhunter. -2. **Depth Recurrence (layers 4,5)** — Virtual 13-layer network from 11 physical layers, zero extra params. Activates step 3000. Source: PR #1204 @msisovic, PR #1260 @dexhunter. -3. **MuonEq-R** — Row-normalized Muon (arXiv:2603.28254). Source: PR #1260 @dexhunter. -4. 
**Full GPTQ int6 + Brotli + Compressed Wrapper** — All 66 layers at int6, brotli-11 byte-shuffle, LZMA-compressed self-extracting code wrapper (~24KB vs ~81KB uncompressed). +Five orthogonal improvements: + +### 1. 4096-Vocab + MLP 4x + WD 0.090 +sp4096 tokenizer, wider MLP (4x vs 3x), higher weight decay for better quantization compression. Source: PR #1218 @clarkkev, PR #1285 @dexhunter. + +### 2. Depth Recurrence (layers 4,5) +Virtual 13-layer network from 11 physical layers, zero extra params. Activates step 3000. Source: PR #1204 @msisovic, PR #1260 @dexhunter. + +### 3. Parallel Residuals (from layer 7) +From layer 7 onward, attention and MLP operate on separate residual lanes. Attention reads from lane 0, MLP reads from lane 1. A learned `lane_merge` scalar blends the lanes after the final layer. Source: PR #1204 @msisovic, PR #1289 @MatoTeziTanka. + +### 4. MuonEq-R +Row-normalized Muon optimizer (arXiv:2603.28254). Source: PR #1260 @dexhunter. + +### 5. Full GPTQ int6 + Brotli + Compressed Wrapper +All 66 layers at int6, brotli-11 byte-shuffle, LZMA self-extracting code wrapper (~25KB). Source: PR #1019 @abaybektursun, PR #1218 @clarkkev. ## Architecture -11L/512d/8H/4KV, MLP 4x LeakyReLU(0.5)^2, XSA all, QK-Gain 4.0, Partial RoPE 16d, LN Scale, VE128 (9-10), sigmoid-gated U-Net skips, EMA(0.997), MuonEq-R (lr=0.02, WD=0.090), depth recurrence layers 4,5, full GPTQ int6 + brotli-11. +11L/512d/8H/4KV, MLP 4x LeakyReLU(0.5)^2, XSA all, QK-Gain 4.0, Partial RoPE 16d, LN Scale, VE128 (9-10), sigmoid-gated U-Net skips, EMA(0.997), MuonEq-R (lr=0.02, WD=0.090), depth recurrence layers 4,5, parallel residuals from layer 7, full GPTQ int6 + brotli-11. ## Compliance @@ -33,15 +45,18 @@ Merged SOTA (PR #1019): **1.1147 BPB**. Delta: **-0.0221 BPB**. ## Reproduction ```bash +pip install brotli MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp4096 --skip-manifest -SEED=42 RECUR_LAYERS=4,5 RECUR_START_STEP=3000 torchrun --standalone --nproc_per_node=8 train_gpt.py +SEED=42 RECUR_LAYERS=4,5 RECUR_START_STEP=3000 PARALLEL_START_LAYER=7 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py ``` ## Credits - PR #1218 @clarkkev (4096-vocab + MLP 4x + brotli) - PR #1285 @dexhunter (WD 0.090 + all-int6) -- PR #1204 @msisovic (depth recurrence concept) +- PR #1204 @msisovic (parallel residuals + depth recurrence) +- PR #1289 @MatoTeziTanka (parallel residuals integration) - PR #1260 @dexhunter (MuonEq-R + depth recurrence impl) - PR #1019 @abaybektursun (GPTQ + XSA-all) - PR #1287 @dentity007 (base code) diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json index b323460d9c..f05ecd01c7 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json @@ -1,34 +1,34 @@ { "author": "aryanbhosale", "github_id": "aryanbhosale", - "name": "SP4096 + Depth Recurrence + MuonEq-R + Full GPTQ", - "blurb": "4096-vocab (sp4096) + MLP 4x + WD 0.090 + depth recurrence (layers 4,5) + MuonEq-R + full GPTQ int6 + brotli + selective pruning. 
3-seed mean: 1.09265 BPB, beating merged SOTA (PR #1019, 1.11474 BPB) by 0.02209 BPB (Welch t=-33.5).", + "name": "SP4096 + Depth Recurrence + Parallel Residuals + MuonEq-R + Full GPTQ", + "blurb": "4096-vocab + MLP 4x + WD 0.090 + depth recurrence (layers 4,5) + parallel residuals (from layer 7) + MuonEq-R + full GPTQ int6 + brotli. 3-seed mean: 1.09042 BPB, beating merged SOTA (PR #1019, 1.11474 BPB) by 0.02432 BPB.", "date": "2026-04-03", "track": "10min_16mb", - "val_bpb": 1.09264648, - "val_bpb_std": 0.00093453, + "val_bpb": 1.09042220, + "val_bpb_std": 0.00163473, "seeds": [42, 314, 999], "seed_results": { "42": { - "val_bpb": 1.09351750, - "artifact_bytes": 15999165, - "steps": 5415 + "val_bpb": 1.09230771, + "artifact_bytes": 15965928, + "steps": 5279 }, "314": { - "val_bpb": 1.09165930, - "artifact_bytes": 15963773, - "steps": 5415 + "val_bpb": 1.08940214, + "artifact_bytes": 15997318, + "steps": 5279 }, "999": { - "val_bpb": 1.09276264, - "artifact_bytes": 15977496, - "steps": 5420 + "val_bpb": 1.08955674, + "artifact_bytes": 15990607, + "steps": 5279 } }, "comparison_baseline_pr": 1019, - "delta_vs_pr1019_bpb": -0.02208861, - "artifact_bytes_max": 15999165, + "delta_vs_pr1019_bpb": -0.02431289, + "artifact_bytes_max": 15997318, "hardware": "8xH100 80GB SXM", "pytorch_version": "2.9.1+cu128", - "technique_summary": "SP4096 + MLP 4x + WD 0.090 + Depth Recurrence (layers 4,5) + MuonEq-R + Full GPTQ int6 + Brotli + Compressed Wrapper" + "technique_summary": "SP4096 + MLP 4x + WD 0.090 + Depth Recurrence (layers 4,5) + Parallel Residuals (layer 7+) + MuonEq-R + Full GPTQ int6 + Brotli + Compressed Wrapper" } diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py index ba30f345c0..ea4e233c48 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py @@ -1,3 +1,2 @@ import lzma as L,base64 as B -__wrapper_size__=23447 
-exec(L.decompress(B.b85decode("{Wp48S^xk9=GL@E0stWa8~^|S5YJf5;XFo2+g$)On@VT6Qap3bu0*kgCR~YUqB0W9R)iarr*QtEZpesGY3>~CZRiK|6Dwut$nH#N""!RYqQnA}G^`ZsFO;ar92)Xt#3E3Ki5S1}OfSx<=$c<4=h|J{kt$27^CQ01M+lVgZ0tGgX0&I*V@{U&JgYc0U!(4F-btCy*+qzv6D""p~UW!y~6{U*}y$E@2-R}vd?t*s#fnDO{!j>OImt34A(d+9n>hnnvzmd((_D1Cghg~(bQ$Yj)>!Y%{*o9ex8FWa#U)!OI!!5Prl^?bnBX2V(=(Bvc+CvGo!S{LhLn7pSsR!}@""U=OBW0)h6IYneQ1{|$<&k9TS^qGQpb-;#vEPAl%11UF)?6mtC8c04XzR$+h2=j84E2|i`pOEt$uyM`lGs*ejIF-}^SvSRZK$ePh1""`gt+?%1r#=OVy3pW`{ofc)6PhQQRP|_h56zl+sQ(le1eJ^&&qZxdGb15""aOb^-R1ouqi-H1_w|H;g(()bKz_!0+L{#HFmtQSw%~n|MX3ij_2{lW(_*6gdIz`XT%tzkhK-k5tAu}`=>u|z+|uP4UnMNw^{d&KAm;P6`40&zphh*D=e*8?KGZuo~y*`y#Wg}r(PV}?J$Of6s2V+h!p=48x=kjDY_aX`;C#9(p^)jAnKY3g{5DbCkKssjYrRZXs2-`q~ph{?^weRV5FIKs2lBmbcg&RzLuOlaj@{R`Bfpp)eb""D}Ew7C`>;2>B3jf`(s3Yq*>}BxRhe{CvdtS;G|JdNs#+jk4ep3dje2Aqp%-8na-AvoSenabJ|$$wl8x*@Ue0+ZLIPX!Z`X`fDGAy""*mN{vNnb*|);N0fh-Z`3Aau{>lWape5IR~oJnXs<;zeX^21})nE$qd8r&8ti65NPa0a~or2;Suxwo%f(tuSS@rF*DUP|;0zJT4F7""RlJu9oEJotpZlOYK^$6&@1D6Ybg$yb(#g*e~5V+DpFWW@xtT}P3;OWc#i`!T6_OgC#jrG@Z-u}yx8""D+(63t^Z!rAvVq^;Hlx=NeuDV(m5K`Lloy+!tf4a-gljma9pI-QkvA4QxjGark_SRsKw%OOGOK?F6!GAG}Y3eXW8Andz%""B`_|L#`TG!i_Sn38W@mXk8(m}NGMRyY!nG7#(7q2kx=|C9XxNBD8;s;!clxeRlYzvgugFsQ?+B#pA|G9oM#}UvB}vG?(T|bb!%<)eY{jK?T}o!!L;8MX""aR9n&d~6*Y^qQ+_l?~MMX+}0>=iu&u=R#57w@o9&)|V4Ug=h=-{88=2!gonRZ6P5zIsz|R)AX)sGS^%ca4&c`)xE88mbC%}O7HV%""CNLM;Td~lRF~xmpFQB%OIB~VgMRMqx`Fd@oX$o""jx0VVKTM%u$S1Je>ak!%r(}nlB@GQhHEW<7r_>z&RVBXF0~M6WaUEqT??3}H!vCwVrT-+aQ9}8eKZ(4ODp$WCE?qE;D{_IK|WO`#~pSTdixgJa`d3bV_y=L#oq?u(cX|U?!hP~-ioP)ZUHOJ8d*};0VgaVl!VMXCyKhG`+6wQ8j1>ea#G7dHNownYeKS|^XNy|f9HQOs-OM*>n;ie""nwoR%RYcjeZpqUUq|Q{@52h^IN145J&1a3U#~jPnK0aK&d3i~qHdLbi{H@k;?us5Y*zocf*@n^aB6E|!!fP4Q%Hx3nzYW1QQZ42y^ru2=G=TOOqCbw|xjTIna~z""duyA6$gl>oNX`nJ$UUlcQd&5t?n(mSO0z~;b2Qh9VwLd9e)S+5v8AUW?kV}GE*zZk{M460oKod9-ISQI""0gV^pJyb9~6iut%7p{klC#xEK#h0iDCAiLwnXUt{`dVc0%c*DFNthN{dz%#kkYZtc{vo~8QH!u(v)*o~Ty%+M$x9ShSK^=ObXclGly&&lCQ@K*MUVnM(*R5A1(z=MLI(2Q48|7$6E4nd%%NF{(r^Fv)1$j?gL6a""x2I8rJ4sL*y~@jG_{Y7!xZhtFF&*vl3X8r7=L6aAEs+{?PvR3L?U_I(8d2wrP`dD`xPr(9C{e;ZL1qbK#W""#js48ukMqA%N*qtQ&82XzfZ4p0oj*zF~)i|L-yQ@T!x8)8XyhR!+us3b?VZyH%?uQ+uqf&UUf;kQ&SfWg=E|;-8sbXd%4-YRUs@V""?&@A-0US|{R)r^Eg@<3ab)O^=a$|P@2-ZAptnZYae`;5q>$>vScXK6fpvUImA5X}jC|SL9Fv}!2)I$VjH?Q<*o7q2W2jBa_Nik^<7Kl$|iA<{3*mO-USuQCZun@~zOpH4J3=Ly%|4%OJ7""Dt%mRZX%&&HCVG|%ncCT8O#m}FzZA}gc`F>vTmqT*7it-7-dlWLE5>Y@AH|;Q#RvQmSvA+;j)-5)cy?A5UU|sFy@E;VbG59""vYiAY1@ywh%lQIVBJg|#q~wLHJen`9O9|}SSXdE3>2mGh5-Y4uJdhpWJJ%=8lGAf-yS""uvhvwlr?87A|u-G*0eJAX+W{aZ}509z~Hu_R5(uD;%T;oYEA+)MN2kd`(W3q>N#~8F4(Nz9+_=y!wSQat|PyNt4H*j48Rx_unaWX""#Ixl{=cbX-d1*Z1U+zMJ@88$@b3gu7isTsv8he{cdlOtzOkWU0b#Y6F9ZR?|GB6nZW$bW4Ll0~Ai|P>Ke}P&-$sWsP9v)-@|Jo;|6_A6ys?AYPM8(a|Tb;%>|bj%41&Unu@eC>UCY4jFsaHB0b;9BBl4guf*PB1LO~pMGC(@5e)M{XJZYhos;`c`6ZK{""4Zn|?!UdMd}TFsV(8kep--vCvs&!A-n!!`-;dK`UdU^mrW)2y_=s5z<6ZN`%RnY$""VYsJMF|Th|RxFK-!OQ41VB9|AJ|FmyCge_f>(sg51RZ^mz13KU9F`+iEc4aDlJ);a_SQtn&@sMm$""^cCs#9PqK6K3Qq)h++{Z7N)~3ZJ?USgf3MRM!b3x|DGgBM{3+ndwH-L#sF`fKQ`f3^CGrD)y_z^*pVEGFg)km_e20hzTw9SnIo?1zW&*B!7tL)LvSFJZnsJqK+0^m#GJ)}@(G*Kfg}=!Gg<<3>H(17{CX*@L5|LI3clm!hr-""o9fmX0Mg=W&CpwY*3?d5X>pGB^sN?c?LW;&=eFE?xaVdgKXw)=(d5zeC{ILJpfc7go$LJyQFyRfN@@TcUxQWXKMuNgy9VT&CSuxJf4r^WO|Y>a(v{o#9u+-cOZm;*CRzFE""FrJ$k40dzzv>*?;eIoOIy}}edA!*$t_!T@r;C7%a-LhY4+nV$yxLET-t!CWM+;8ZoF}p>vls6Glza{|D!@ism^#HE~l~bTsX#-vH""L`B$xZRMgr!^o)G18l""&pWTb6!qFAG9=SkXyCz;Y~Dt$gLZXo6RUOyO`Kv<-z|H0WT3ErkE5l>ee_*f3dGt5m`k~RoV>T;WUzd91)ELVco*S~oN&UFg@F4B""0kSGPnjohR6H8ftuE2vMwpU+NADl@Qmm28LDJOzXnn4)^o5ZTsY8PJ}gU@lh|S2T~3>a6;C}k?5)`l""I4Sq>?6Sk@A6e!)sG>HT_|b2ModK
^NtFyt&Inr7m(~ITt78<4gnP}Wyf!D4C>U;r^B?xnn=`u8PQ^;$2rd)*qHvhUrQBW_)~zpR+U^OTTm#H#y!r>oM=EJ}I1J9({swy*c(b8EFXsiB#_T+wcZ$i5Ew8""de0CXD&_agYzTcwM+XKog&iM%-&4c*m(hc8?yLUmCB%e@e`1r1+~&E$_5?jQ!90eDf)JW?1T*ZGKa!xKq=Urb*@jWwj""N$fUxuUv(G=ZwBpJFkf8DcUM7vv5W3Pd5KtIfXlarmSj#rs*j*f%JNhiuhUnU8ylDa$d;3>#~U)RWMGSh^QLuR_u)r$&~7DGW>Fj=""@gEUnotT+kFFGIvd5mlb=ghr4jJM$=)yi_hCc>6BzFyx=QYlb7KbUR*F&HiHHLpL$1#n@~>V+0PTE)""f-$UA0}cAWo79P#g&lGNk9AYvmj&Ug_yw9s*IvPC>XmQ<%I8qO3GG-3vvgEdkwbC8zDU5r0ca}VWtiZsi;cu1q1?eY`_qDXGf)lw""@Fy~a7J=VUqLy=9cIxF>zq<""10YDIj}N#8YKAQY@U4#>C2EjAZK+J~^<#LmPCfUWntU7B#xOHKxgX&bUF4=614Hc~NOFm*6VrG#n;kv4Q<|0jz-nCA3P&6i@YOqr""0PeLhDL`gvdmTKz^1kns-t*@p|N0#&?`7pe4D>(#SOHw#HZ7aPG)L5KSnJ#KnB&^ZqwrwRdnITo!MHKvi9mG%1A~aa#3U0t9PRzd""S&2Qu%F@{Q2?0dzt8;tbnAC~X-hsHS-Q@4l#@zlr1@M#SA{$7$B0@{Nw>|Pe?pJ#tR>i2s)Trojropq`gb%+8y4jh!4~l1Pu^=Z6``i3IE>rr!S;T""8d~dS^3NhYMto;g;_#!n+M-H$y>FW(BKs>K+Bde8bycXj-Vqk*^D7Ljw`w8H3Mm#g$GUPtJTvT876@%u|fz~y1T@zoZ+1L#!n*d""F;zD3?dU0uTwqRC#;4Hrz_O2j5p8b&avrf1pzv8e-tNLbj!p^E&rUr)IaYYn?Clkz{OG_eiaA~j8pAZ0Wq=TXsXV5&Vg8W8crPE(""%fd*8_S+Xm6oU{|h&ugSL4bW%PnBRu^$!uXNykih*dMJM=Y+w}84%m?tI`JHF!8p{YgJ&zF}Tvr2agKniajsHmer}>u^clZE`uxx""f6`@@pymi)QGQCWUacF-y^Bwa)@_~Z$^13VD!X$`yCBX1Ry~=00t;+?hge$We)4!Humk0Gi3hsem&qGg6w8=<{5>&=y5RpOzGmTB""sTkAbeW5CwR@zmoHd_qyCCjNbCEL>`tP&@xC`Y@Zkui<}<>etKKs;@f#}iSmO&RUSLI$4^`?}v=6>^!nLs+)=w!E~DC{(=}&im8@""eMTc#)GhJ*e3^;@1$eZa~%aXgT~ESZDfSu}^=J?F^3YF8&CA`nc6afs*d+BjWpGBJ|T-j_pbvS0OpG""2;}ucP|*>0Xj~GUA@5q&IzbN6txVJlSZBK98V=^Qz+(|vBT}X#3KORC1vU;d3#p8|qZW(jJq^azP0FrM7>$47ealNHhpFFSdHPaaD{7y56^40OXP7C%Q1SN""Tzqb#f%VTbZjN;nBAyZ5VIcYL|{JY`0lux!PkyZ;WA4v%W^{1""$fnG+ExNXS#5$""6*?6Q#1ZxQKBKQdfkqVJz!Yv7Cd?zL8&cGwx?XXFq+`Pde-oZ(CB?CTRg{<(W#6_Og5tR-;n#7BaB9-CBUqEf8|60j;C6!v2!q|y""I@#8rUs$H|s`Wez;CSwn*AQ&)uJ`6ZJ~Bbe4z(YcD4uvJk`(E1Ixrm5V?-yWEDXE;iAi!E|9xC(U4?8Vz&lAdz-^G`J3j8|7m%VY""lvBVktpAttE@~G*s605}dKp5e?&yk_!UR(cq2u#-K_9f4h6f`#D7K4{ia7PZBN5(i)3iw6Bw0R4%w_IB6|ZviPVJWfG|OrPX#auR""Hp+$A*4>Z*0D<~|f9R^Rd1bPdn4=6@qWMfR$C*`vp%ne@mwict+b@_1T^<~2P>Eq(4n9biVz5Qg9ral$nT7vLal(PQ!PQ9Mpv#5E""5nnynf9qaufF5v4vBTNMxq&!Ui&wT)IXa{-!eHg-*t=1_cFiqhXzS-WrN99H~G6Sj}T&m""l^<~oG=?p{XGi;-FuM?_XBsDemRO!RN{-a@Qq^vNIpx(edVFO3UBp_h0x@c%=z_8;-Hzxd|MPbt3w8-^082;6tylol1ipjm*@mWP""CDPh?6PH(!y!#1`1Stk8-A1*nVyvDc>xS^$0MfWuzEe_+YrSEEfB(DHC=(^R{8ML^d2u0~TAv{4Vr9Bs*_I$qz(IBFHYh)@L6@4$""Gq|z25&7eqhEP2r_+N&wAw}8UxR*+}qI>`&`e_Ai""is>~@l0;uj{Ga|_1puPg1){E{Pj%h!3{K84{_+T()o=+-d{ai!D#7W-UhvZvsOkSQ^d3B43q?J%!0Kt&mSdN1p=)!8tOP>)1)7r-""rvL6e5MUtm0amt25j4{Gv|e78jK%7*R@?3$aq-7K^4<6Bp|t{3sR8-+IKp3w@w(T8CbqWWiY?Zxm_|ShVeNqZvznY4*&>P+++2WR""&_9ZSZ2Dc}fx>zn+}p1#=)YHJ|K~hzDz~vYjRss7Beh}3c_d05M--doRCHSwME+@|{DyIhcMsj)Va50(SI4uFRjI09a)RIMpQJ7e""({Akyn|Kov)>VA+16on49gFXYSbpRBo&1Jt{8~*2>QlWwL`8L{T7KZrn9fm9uDMDwCvj-+)(5%m8WT2fZcpV2xd7h|&_l1wKpZfT%g?w?X|l(+?2k9g3`4orgK}IVm)L$N)aoYU`s{SRwWS&AT3@Uu%Qu$@nVDLUkqk)G{3iI`HD}""(e?l#mwWtzNosyxHRzTbEM;fgf2Z_~eozfer{BwpzO{(dvP}>Tu6uNCGGmn-R~ibjV%d85(e)D?jg2JS$?u1)K(Utj%j-Q*J⁣""xjmY{Po@TxxSxTSw8Qa~I94%oLGXz0tyejhuHEg6;pXp#i?4k1T@~mKTc)sKh(cp1Vt-J1$3j)(oB_8-zV;~MZct2zsWPxTJrV?^""M}W)SIJE3}9GUN*8DOxG9E1D!F8$+4+c_ExBcp2i5_C84UgG*6vnKTHqDyKb8L--qq~>Jezp@hvWq?R`kLLI)rWXhD7Ji^O3&dtG""(RA0?g3ubu2Epb2qrF$28H{MJAuv$?bfQUq!{ElUBJaPaVgdT?0JivdjY8L&HLY|Psc><+LRA~nY;{17lGa6rFv4Zr>>Ks3!Ofmh""m1qnYktJG)D8<8HzIX>_$7NUOROEO}*4gHguxol<@{<&~#~<3=X}4#*W~OPabWetKM6Qq%=|ve3lGp@nYHMnOb#mlz""k!0=&Zx^OMPm?E-lhZ?v;MDB!KZzQQqc*gkFd0QN+M5*Bl|pPvrd0h~(*nyB+DQ%wl?~^E!tZ>SekL9z;4fGCw^x)88rGhV-NhI=""+c4$VEjl^ENH16yRTbIle2I81O2k>uMgd%*h)x0Bg4dTiH4dr>q5m4KDgb;YAcT*Kl`H|=x5%EuIsV-gDUfmiV_8^>Ojyj&XM;=y""09Hzi&`SnvmCCq>Tk~!C-+zm
+-A-@4QD~19qDCMWMrP_IU^<>6)Zj-pw+xly8!(ukoESDA9~bJzMuNVA;+EfX4JlPQ*=QI{LG+-G""uP||j;mTOmmMp==2KH*YVdv-CT(h!KETO0X!cYg!$HW(`dN}SxS0D?uJ^%zZ_!z3}n4XC#L{eO$cH2Xw6aO_!a=FlpjjB>v`)QULT*up^x>l^m0mGQKsos8*F)!hdLR&>*~MCI$fF{o)tq+A(DX-L@^DmmxtBK)}*bF%Sg`U?fGAdk7PXS8zNcW;H>XI""BUAlJ(E~(~9OPp=XABT=c9d183MX3@E=yZ#{D$_+C?enNxvt31M_yNt^YGHyk>NR@BUq16!bAHe2{mO|Coa$FDdH`UK#y&q-BS""hm98l#YO5yYLApAtwj?upIM9<^=3davWpUdjSY>DesfDNG0?q8C4_Ka#H07dlPz$YI(@ax-D~&;GSwKd*o@>87KpG~9VzUQ^4l?#""s4Oc9JzGzINAq5-C9mw*ILwAcH1(b;;tH18z|!XkD0l4DWU>}n?X(|i4mQ?lqeD_te3exZOTHma+9Q!w|K{BzIJ}d<92PQ3puB&{tB>i`QrnW6dQKbN<>|}Ry""wJ2>cN%WqtnC}228|#>qis_GvPTb7b45elvqa>I1LvZlLS;A"">t)7_|D<$jeZ+&$@nl5thh#gfc0_ORnOeo~Gk?073gFI$(PDbSt(fARj_TaKN%+Yf-r""AgNq)ehb&C?zv;f22d+W0BN!LW%I!Zdj;D8OGma1XE)dqgzP^ru#P4=zYL8OQ;>OX(6s&*qqM*yCG<&m1#}h4t!nh&Ez4+Y$o;jDzFcme0dLZeN""LxQxvnAcFXZA~K#wHyT05O6!JJWlAkyy_jOU;HwpWG3ZRr2RP_8Yo5k`gItcS3mfr`4K{pYF7qbt6f@AOjL_#2_IggA{O6&PcscP""_I@sqdnGNFJo!4H^na~KDIqD;F^vDAAL;}efLMBkj&CpawXxWlC4r1WouSYaPivVDHC7M)PeS~5DNu7xSLyIp5`8C|8|-Vb)SzbI""v^14}|I&@>trWz}$^;t{$iW5K()XInKZ$~d&(1@&iGOXXNHlz=aT{Gv$;0lDtJ10SS3_RfI$yC0?g|$cS>jPv$kMGlk`X%S?jeclev#q5%QJm=#s=H_T(wLBEsnMS|y7""hvStmR2_N2MfcW%HU)c7cwYGB><)m=+8gsxj!v3ve^^--7xe1ZRPV5U1&O8hqcuz|a(jDp!R0to$YE0Id`g%q2Pix-#M1*?bTO^S7;{bKK3^1fUWI~5a`==oV|!{manu=V#YM*>*#LjyNT5)ThA;G0$}hLtxDMn!(a6NL*aTFA`7Hmp8W@""5$h07L=yXt$L()&2QVr-jx#;-N+4I`9~B@;stk36xS""7rvKJOnE}Fhy5!qn_FgsTI6)^Ke""0|wOg1?vz0a~Xj6YR$dGI2i^qzHobOnzREvdFex1{Mu>_O)ThEnC!mDA>>5m9n4&E`ux$Wz(8V9K&u%2e9Aa_0K;#~7w)>o*U&DAf${Yg`~*j*#YQ?j&;|bKs7A""liTgnLb|4rN?gyYJBsqMT!R~UnojBJc#%5Ry;r>p2kCZKSI;0X8&{Z_0X)4sy}*lHSo+eD97)>8s)ZOR#*id_UH_r@AUj`iJ8TcvA1(I%U`Ok$zv""r`eFH*d*>S`JSnev`_fVN9h^0*S05Xi8_WpC|uD@=ozyTEEvoFL2Ni?rzy;pHdM@sy3LNZVpKT9FxDx`>mTz+ANC{>wLBy5fMjlY?rK%vAMq5KD6zbpdZF;zrwDB}~+9%t@#;M)(R?Ml6~"">%A89thn}(vv2E7NrR)$pQ1niiIIKft3oq_CZ0jER$h+j8@qlw&VF!mmY+V7er+N+?{r*XrntXgc5znz{#ZGzy+lFN58mWF7hnII""L=r%sDzdYvaEFzGX|MutZ@oGMY?0~=J;)_ry$=wF2515~Yf5^vx""baeGPNNN)9ZF+y`UKct5KxzN`Gry#-%>oQK&YHvC;JYwK))qPffbVac6=iey5y_36Q=)b~@$dzwP5Bnk@<0$)6Q4ED%cQLMA23QH"";+NJnzP1(MxM1?8Njqfs?l@q{w!%#c64#6H`TEx3kKP$~Y~EU~?A*K&?X2nFJ3sX;aVbe|9}Ud92)9!U>U2S@nBU(;Gv?^m@S""wL{oSTz3f2i3}&Kvu?{`Y<^b7ieYWwuCFeck~?mZknU7Ydn6}(#{GKwF+drBjl(6}Tn#sZ$oLmee}^7dm;~``|Kx!pl}YP+UI7WJ""efyM~+y`~+tujH9!rRoT!rG!YwPEJc%%3@_%mNR+x+dw`QA2ppgCym{L_UBi5f_TIWK7aip*9dJ)Y`6FGZQnqK(PjoZtJk$lp7*(?99b+QQ=|2(wskv%4O=""f8W%wW2KwK1!;|$9WC$qe7B5pcC!?SH^&zJjneE%i0RC0>""Z0p3*XbH-c1dq}h?(gZ~Fy-E>DwP)&?dtA)+2;G_l`_zS9%kG*4YV""uB`*dx<(EtB`x-YT-~+|lLa=8R*!G)gTN(X#aNsN%jc4Z#OLc8)FCiGg=Cy@rajmW*cH4PR)uz+-{nbT3w?nvVU*_`0++fJA+u0|(>WI3Y+=n9+mpt*oh`%L@zyYb0dRgsU4TK""uvsZks@n9kchPoIE791S`V?`MI@*YVllPO`f%za+)yEgv6yeBt96A8q@CthCF@s;k<>@ROr7(aHe4t9rI8E7V#H^)N%mvpejhYQw""mwRl3mj~x-ouNua<{e_On7FHG;_4yPsZ&HzmJv^xK2kE%_UH8yB^IA_6LIJL*i~(J1J&$IUD~V{1e)E8+Zhd+i^?_f={#1}F<97D""NZ5tjis#;a=3Jt^apAjuK9|E+q@*(9}&*(jK@xHH~""gH*)aw$mB;3L5GYO4m-&d-+4t@duN)loJXmM}a(=hAj()u79aGu%*;#PwS`gq89rI*BMsp`)MqwUv{{TMap0u?4(TFl397{l;kkt""tJ(LN6%{8inmnQB<(QY{L2h}zvcEYo&b+L>VUY{Ag^|ua1XLRUZAqg}K*rU^h?I&?poUYt1""=38ndiT_kbHX=5Nl1Ys~+O;`Seot^aNjK8@846WEQJ!xD3-OwUVI-$^rtPH+e_z@RwW8E9tx>wtQ$oBHc|%$1*_EEF_FvCuP5K8-""(r~l^+eqE=e><%I_cn9oV?v$A}}V2ZElJk6*TtiWmeKb+o|)B^$wi!M""Q_>R$lXAL=IdDtzik|n2qi9_WS!=)oh@NQj0jpTDE_02xvsvXH2f}zegkCnVR-fJJ4=BHkr6!kRtKoMC{pltB-_BuDh6wy`rTM(nVt{-LwscfvTf~OzedS%Hc7>3)Eel#sU)9-Dg0^|f8KR2iPf~Yw26@W#A&c1!#{D}VNwD6(wXGfAodfL`o^OsMxrSQJ`PKirZj#P#0(m@XN-)HB>_Fp|G7MI""H7~idf?plwJxi?r$$D0F1?LwNn84I*5UdI4xk<4=Y)1U$`|hQZnROMUOWg*$gSs4jyx-4~Gz!Su3seC`Lig2z_`_cLOl02Fs91@_NS1#e-4gmKdIIcC5C9pFeI6q0}v#KErKC|@~giz3
DlR{L$8By3b@!MxG-7u""CbcPX;IVP$Z_OX>pX0~}Xp-W+gQQ)bsbnL#ZBxspUwIih4tv!|q_6jA35uxMB9i&C7V-i<9=nm(SAa#Ap|@wRN}@WC1x!ODJb8k@""{VwjuL3RTlH;`x}3OA+M@gOTI0{*wu7QaATlR#ggvJLjAdO-5@ni61GZG*bzGUQwpcMzfV3T$%&G%s{6ZI4CBDcRfxBd6B22oY75eVsHA$VxqO%E1@Jm0QxizI-qor5{#E{6wo1-t{^vu5>iFrN-*6u=REKs=A&nk;""S99>l%(0N6nV0l=v$a%3h8fXB6p&HlNCEb@h1L&HQfwGRg4ufgEi_QsF0nZ{6ljnE+J62#DoY)}{H$sFliz9zjeucTkB~wrCGZNw""b_a1lx{;6jt4V2{T1MtIr0{bP6S}pOw#_Q&eGx+t{&ivW^z3U^t>8bFCSTuct^=yDCby^1h=uBaZaY6+*iD4tJXNfu4R%sxG{-z6""tzc)L);UGx=EUN|;VidC|1AtBC>dd*$%N5h{>_!+St?6mJ0vZJ^V2Y""#T+aHjK);yS@yagNnbL7vp6!A=4Jq2J|$&}NUC)^=W5LRW8I|P-S7wimo0JceaV}Krp^2ccuP;lrip&461DmZ+SfItk>FX0$L>IQ""pWF9SvNB4*%NZtJ1^CD`mYH!Q)|CN!^ho<2G=ovfd({5l0ju;P2!L$;!6_;l@4U@D$khmBSnVR*6P-|e*FO(Asj&`Nc)8FMR+ck&""TOV&%t&c5^yMA~gV=85_E#uGEyTkr09=fv-X+`rWOz3D{w@zB=M_}ZAtKMS&TM01s@#tkdI$r5}RSRJ$D;MX}AKCIy@Oy+Aw>!hU1Iq@t(r""^ESJzE)bX)s0!TzKxiBhk3Cbee%6_E&U;!gkHcN2uDw7i+v7E#bU8ah)v>O2I~GH2$3aayjC(R1=e9)k2-H63DJ+PF58vQz$tulW""E?PAK0thS8Vte*D7iM!#5)~p2$79>;PTJy)GJ2DrP-2*vn&X&;6Ryi0k2V{+&V}@S59&<&({c5gf*f^DwmsXQ^))VS%{)!5s5=urVH`)G0QIlXfijW3f=PdZqq%2fe7wpU4B&ZYRQre*n_^mtlH~8Kw0Z!&=i6Lk0SUYJb%i9""b{r%ZrZI6$RHQDEc%`s!*f78z%6^Oic~cI#a7Q4>N+Kwc<7jjpFC^)))if`&-FptV85u>Evw(tFF)nBFq2wK(GS8A~9t>ah{4Yk=""NR9(!OHIBr({tMiW`!psuu+47JjnF=dd3t_h-=uC4>~a)H57HzS%Jq}YA4Z~x|ZkTrgKK|6?eNtMCNWf>yoXF<&@7@z;?#g(=XkP""Jb+Z4%Vuc%;dMuqD_sT{{1A8fTG}$8W9|o@Q{RrC^2~QJO5;OUVIVs4V_K_N8za{7Lr56oeHni^m1Rwh7Jx8~qimxKZwB}xQ_eXF""y&D9L(J&rYY+3AoJQH#VC;m{SI+$w?f>>`-Eq4uI`mZK}p|9*-|Hi(pX$a}I7TcluXjT?Ad6W(h`gjHkNeKgJ""@&?p6B+#hf9mu>C!5&_5p0=g4;-5128x<6Q;PnM)GdyqW*Mj^z?-6r<)>hg`h0C=B>8@_`iExoYUo4d@m|}4ZQv9Ru|#q96iGxN>K&w?=!D&GgAKx9S;FaMaa8GTov~P5br5>${?`PVm099t!4myz""!rvCdi<>;!qE!MdE#mA!Lp_k8u(T-eP^jR^P7Wy*c51My>N3)7I5PxVh~q`^%|aLPpEY5|$~fd>mIFsZVlxh@wClZ)cdhWr8fzfEXb1cr2NqCMw5-45;)UR)H{@G7dp`8c8g*!&UxRQXUugpL""juDw=kR60VLF@7W0goL2N*1u#t&*JL+<_@_h*xhx%0jt(eA)~nj80N&IHLn6xrcIK&Vp^6JX2R1+EijuxKxZ&`_-pZ&RS*Hr_i2U""=&S@jYUMuY)(Eq5q6A1A&_Rx@CH2_;cUenUtP|I?qMV-=S;sO9n$%3=!`xfZ!KHPW1K97$AWPT""n|9uc?U%ew2p>DQikRxuP88g<&5n7x6;hBXo%Z8W?&OhsfH>~q4IjyKhV7)1*=W``cu;}y3r0DT&Ojb1$wnd6Vb9yj8a&NQW5t7%""B0qwSL#dj88V(}9o_f{Hg1E1O_z^!p0kpzH5UIXE3L_~fg*bZ{TIdQ8I1V(}qs?hX7(zl;%!2PgHVSPlIs^FT9oCO8Br6)x_*F_Z""tGWA?0xA66h1-U3E^z#{1l8Kox2Xz!O6dyJ(I3%{;)ECjTluN#;w-D!3=`KBd5k6%Y=N+PRn_%jON8*7^3A9&uMd9aSm{d7T+*_F""giN_0?JOmdk(S;y@jh}LZ?&!v~_!1N>;m+3Wuv|J$ndx""I1Kce?fB?IijjIQwKRw^JIp;7yA|X_M2DyqQG@-yG1>XEuT;%51%m=MA}n{FGCel>`R|W$lDkV|(Z=wmaB*w@>?1W&""2C>dUtFklNrOix)>}v7~lBvjjAuuax6~72^jh?kw4-D?{q%cTVI2wFEvXt>DsJzs)&U?C5#@C^FTw0{^*G#vpMdFxwBNp3tpT$8r?5oMTXR7fTlu_QVPhB""1D)o_ju_dwUD*j?lk0gEI$aOTS@T%^6iOZ8+CTg49vE$mLMw&kh6GB-(sXgzIT%g+jdUwu>of_SO5@zoQ{P!#GMMT#qO{sYSo$l?""sqegt!svC&Fd|1SSr$wjIy2BN<%oP@@i@OBfeW4J!x{3Ntm=56XT4!;!<3X?l;K?V#Kz%;3X5K8S7&ETG#qkpxO#SCG+2`hb`d$4""n*e2g8oE3(t>$8l)nElnp*Y%~t9KF0EK10Lb+H2ws2X8mcH$su{6sbdDA-RN^o1u!idB!xXpx4K2Mc8KUz""6i)PagVT>mfAb2I%MIJ=y|h(mD_T+C3H6BRQ5sC5G&K7?59!wL(kfjEPSNbf^K=ohJ8JW#XfwA>a?Y168PC>{UeT<3IdCc%=C0lE""N|Fw*^v>}36}39*Wz&fePChjHWPD5m#(2Hy&RX}u4RSttddLzChny7yT@jyK2M54^!sl0L{|x-QoBGzundoziQcQd|I}*i)lL=GMr+tqgJ3qkn0?{UMCDSMB"">};Hv_!}EH&yWShU0mx|A^WR^`V!VBGXv%i_Dh2~)17$3X1jTmrI2cMS~Bk54q*8sgyPU@1ezC*+z!z9RSyvE4Y%(PVgG&V(qM`E""`xUomx2cUI8f3JfZ`V2QdwZ>>4D^PQOA~MTNGH94@k!ngC)QeA)5R$oJ;_d+f@p9c3f9ov0aNs#CPTC-qICP?*9r!EM|BZL9y`Q^"";FX3?Uk+Oc#xGKg-RN;ts+$j&9qF?>CAG%=vZLS+g7NiXRVzLZumAu6_HsiSU%NN?00H-r0mt|R(7W)^vBYQl0ssI200dcD"))) \ No newline at end of file 
+exec(L.decompress(B.b85decode("{Wp48S^xk9=GL@E0stWa8~^|S5YJf5;YFWHtz7^#n@VT6Qap3bu0*kgCR~YUqB0W9R)iarr*QtEZpesGY3>~CZRiK|6Dwut$nH#N""!RYqQnA}G^`ZsFO;ar92)Xt#3E3Ki5S1}OfSx<=$c<4=h|J{kt$27^CQ01M+lVgZ0tGgX0&I*V@{U&JgYc0U!(4F-btCy*+qzv6D""p~UW!y~6{U*}y$E@2-R}vd?t*s#fnDO{!j>OImt34A(d+9n>hnnvzmd((_D1Cghg~(bQ$Yj)>!Y%{*o9ex8FWa#U)!OI!!5Prl^?bnBX2V(=(Bvc+CvGo!S{LhLn7pSsR!}@""U=OBW0)h6IYneQ1{|$<&k9TS^qGQpb-;#vEPAl%11UF)?6mtC8c04XzR$+h2=j84E2|i`pOEt$uyM`lGs*ejIF-}^SvSRZK$ePh1""`gt+?%1r#=OVy3pW`{ofc)6PhQQRP|_h56zl+sQ(le1eJ^&&qZxdGb15""aOb^-R1ouqi-H1_w|H;g(()bKz_!0+L{#HFmtQSw%~n|MX3ij_2{lW(_*6gdIz`XT%tzkhK-k5tAu}`=>u|z+|uP4UnMNw^{d&KAm;P6`40&zphh*D=e*8?KGZuo~y*`y#Wg}r(PV}?J$Of6s2V+h!p=48x=kjDY_aX`;C#9(p^)jAnKY3g{5DbCkKssjYrRZXs2-`q~ph{?^weRV5FIKs2lBmbcg&RzLuOlaj@{R`Bfpp)eb""D}Ew7C`>;2>B3jf`(s3Yq*>}BxRhe{CvdtS;G|JdNs#+jk4ep3dje2Aqp%-8na-AvoSenabJ|$$wl8x*@Ue0+ZLIPX!Z`X`fDGAy""*mN{vNnb*|);N0fh-Z`3Aau{>lWape5IR~oJnXs<;zeX^21})nE$qd8r&8ti65NPa0a~or2;Suxwo%f(tuSS@rF*DUP|;0zJT4F7""RlJu9oEB!E%i}B&x5U4}gtIp-6SMZ{+dkQZxc4r^bW|Az""23B02v@$$zZ%Q|e3t}6^@m@%ylth+=fP-&da%J`u${*%cAl8_70iLZEWXUfUjlL&RR^jHS9Tq95&vuGS2b5hx9w73Pd7u!;`;|Sv4%Kl~^O{I5Sl``;c77Gm&9srDuwH$#63K-RNdMg%p;!1ot7Jv0GK3ctSXc9RL;""g&^SBVjM<_jiV|I=%-4fq>|RpMKBQw;b#%_BIGzZ#XxSg@>j8}%*~8RJ6Pg4!7!5a9>N*+AlwvQ(uv4jICUoy4BG((kdo(}s}`Vv""1h!lBD{Fe`&jW-nqAF+l)dTQ)miJ+XWf+Bx5B39&Y&_c7nY6x^oh_zcHJOko`G7%$G7USc_F>xQc9n@jCYhVoUh?$}#BRZ^4uq_2""7s@t04Q9?!>yQDglhT^N2`YZ}W&g%U*+!q1Q8742Y|l(6#?^AU5Z5gV%-a0iERpi?v=HHVfjN}^6}a?}`q(n5@+Q`-%DW;bNf;Ef""(^-HggnS+gU$2LF$rp#h0Wy0{eG0thz}nPW`LO2Vo!k#dtZ=zHNxR4(2G-CZt}-<{jWZDRbsO@SnxD(""b8WvO+#Yoi*I?Ig#z9Q!8bUOMgidxfL3jJsds0HNh<{P}z+Z9^<%S9bL-*`;H7?ij0DLvEv<2>_I^ljL0F4;6hFjrBuZMbyksHh7pF|%iHA@V?2rWapvolwjHFu@DnNgX?""@ndz>DS(WZ6SyuBUx9aogimYJ-6OnW3G;q_w@lPU@zPb4!(=~XaPSAWHM@u`A=*pWT;kfDKmRzGpuc?UT1{*G{W%3""dY+IfM^;OE^&VtQW(NzSc0q{x2ls>V@3_w@lhBX?abxhiaNL8V5lwSXSyvKV(_w5g_wvG#mCE75E74wPhQAzuw>d4XM9z^oih&_7""hU&JHC4k0#E|OMFmdq5}sJ`7<^GvlJ_5kK%gwZiV-ti?K1TaXvL+`{$RRkO$jAu!eq""ga%GNi+ODunTk?S9_-O$glv8oSflYG;ode^PGGk_wSH%kr;l5Zy*CYU8k#YuvUyX?zz1V_5-r-L$7|Esx}6vhGEQxjc&q=L62}rI""E}`($@GXxR#1sO-!EmCKp|l|;ft}Gk2s>$Lh={&QRFYO#x`f&72H=#IbhgIcE7n)0gon>9`z4h9p0Fy#<7i7#anxl#Z4C6zU0j|E""hw3P-5y7kkjiRoBe|H1&9-&@9+ey}!%9d|{7VSq5VV$Y`#CSgDNpe<%8l^tJ8WZRVa|^@GhM`(^8lOXVNm!^MK(STQsj9cH0DW1+ci7FR^{Hl*A1l-t<@84cLt=}eZc>*b-_Q?_s;fA`bkQdo%!bI9M6-!Z!61W!gz""obH{xgfKETjMNfCT+rfuz23HS98c*Jh;7lBjMf;Y?v;B?o2HX%QJG@&5Y*?d_9_;zGYvvH;PW{@i_qlTQJdc*FK3LI+C{""sE~lcPiIOG{{Ublk&=rC=u&eusEwBJ+rFZZnMXU9gx{7w)%nKVg{Y@H$oijVan|1Vp*FVOtRj81VhDi-+>uV4l(^YNHqCJ5Vs9sc""+IvZ)JJ(D!tkpXGU19WClu_8&=JZpniZ@=Mr-Bp6t6q@;2L`!(XRYmsksc@D67xrj}LIK"">@jGp(eYTA@glWm11U;5;AI991*670Lrxpt)TAwhBueqWtFNFd-q{`ap!+42K~EL|5g-G7gf;}Ae@N};sgR{bS+YIg;3eVA5oQY0=jE9RSy8aX@|m#iE!?RCDS*|$Jka+=rav`5(_6`@tjW%$KND+>j>m9bWBN;940MUK<&zs(PA_v*=pEy""#=iKj_vMjF6F&&~+|>%?<&XFI5H0e;TO&;4;9;exq773k(xegV3ZCCP&=j47lU&BUqH2n6+@EJ>ed&PLa6c?C$rT2$ySigJp4O@5""Lc+_fEN^A6-p=hCn|S8P42!#(=W)dbrI@}ezD|pnyg6QZqjCv}%=b(8f@igY_B~fIdDbCiP&S9sx9=(iC(W23{pXMph*Se=0pp;8z8Klv?NX=i*!MS)7+FB%m!{nvE!p_Q6SIud5x&=6JrV|J5zIp43_euOi0K""WD=(7*VVP1oZ=a>g#>_dij3vL_l!8f2G$oy5hHlWVl)=tRiIJH-+ym+s2;Z;uG^Z3?y#>!DIzPy!of=oRgYXZM}1m(mJZ#pM8+MQqnQ4FSHk?`NPYN;jM;9*@9eMD""mPT!=6V#g0abmJa78KIrWo}0v?vM=)4%gYb*znENMJdiy!mVc?{AcVvLW=0fk-PYP~*ubwD~$u^Yaw!zRJseh9S?gRV*!CO?i""j+(RCC|?htf#@1$FbzCO7+l}ugcWir*^U<98~wx{>RyW?JZ=L3C2|}BElWP$g|BHKdFyWYgTKFB5A2p~cR-AMt>&w^s3jjq*oz9B""Y-D->I&`a-$XxXPs%CVc`r}IZg)(kXqLCI)vSjJ}f`lnYWkaX&T@W@WtpM&Bc^N+?luK26B`7%uUP8|9+gjQX{&s%hm7_hl?r;IV""NFS9fHEfzgX5Q4#*^+OfOu1kLXLzf1NS4jQZ41#c()>NDX|v8gB7lf^t?`Av""h}5H0wAaxFd_BC-5
MC6iqgmoO_M-{aukTC~=U6sBEypj0#6%R>Wp?m<>5Vg11maFY_;G`^~Eza^nMZ{JSovKRCy9>38Y0A`)Nic""=oqW#MMyNBnVdT@*GIQWz(L%otkIlZgH>X`@fH&4@Jm_g)?xLmBbZK21R=nEFu0vtdUU&Qxjh+nUP}&GkZllla0TL>-reIDJr<|TrKzn<_G={3i{=F|8kMHy{E{#$Vf_Vj;IVf>%eofKtK""oAa3oVy~p>%OLZX;D^)XKJlJ;gNuCM&G(T0)*1MDl{JF-c{3z*1202RO=gM))Z9P{jXz4^W%+3{9""i2loyx$b#@Uh^|5mwj5$NRHN8cSfUPSc*Jgc53>MIw6vOZcheQRY|=HNYTH*DUZq~(Lu~MqLt|Q%GPc3X(ZM_K8%Zr-Ac)+*BF-%""N{pdt!FBip@hugGl!6`0kU3Ju1D%LODFp^+$Is@8n8{m~$bqDMCge}nRcBEY?0r>w7^cqr5_B|T{""u%&?X0@C7oBVfKL71q2*u+9|i7Ko;hezQK+5$LLfy9tUg%27Iq?RUu&5wljmm~wS)A3yQs!=t_bOl2`_SZyzC5E&u5UmXf6u3VVTov3uXp+FeLp7mB3EJdDKhNofRAXyN;du^8G}rB?vbmTrEOua@eoN7<@gvI""A>P1GobI;Q__y-&$D-9f=uQkSXPp_#4H&xDo(fZDx6U!3jPfZL0Cnvyi#4a4A9#4-~!""x;P{#db3Ln*<&Hs&fx4lE_4p&BF+jAfTP<=(%8?5QDOYmpmpI;SM3wv+Tgr{Lw""&MUQDwRhMFgKxJ}o1KXLqE00p>dg}ER5IdP%%cLhy7dBhqD!vA*{C^LDI_rK9amZ^w;MJb3y!Ll>|#uuz3gncUmbYr+a3;JoPZ0B""fbsoszK&=OM%u(78P@ekd!;kT8U8Y!w3PZBhIO6t8""f!DgpCYrqR@QwEsa828~%0HbxI)o`$AHb=TnE9d>+7}tOBp3yeakb2YPZ4PY?*#GyElWk$s{?9sB""{RLh6#&a>=;g_F+)^AAHvvg42Eit3a75)PM9$gZN>@M=@!Ucr;oGR_M?7O;4DDafPQykq8Sb{j>p6lHoW4);J1-ZO*H~g$~ASR8q""m3DvA$VB^Dj{Yb}8Gcm_?52MJxsUbkS16&1@!#nNyB@P5Xn=^$r+=nQ#Wor-(OHO|T-Ts^aMn#m(Fr5`&*1F27hkU$JqZ)G@UoPg""TCrW@xK1PRSJ_%ryQXOd>!)bid@!_hezg?^~!{""Encs7kNQM%YJ|_gLXogC$kBH+A_C>etC82dbtT+nff3&$FgoY^llp@yK%<**skYKSwFMVF#o5l(fb^qj251+u$w%EbO}#I#`5#i!""i_dPVW&@K{-~1uD3)0d3byNCLlX9f0U2Y-AlqtA&g3=qhVmkp=dP(;5k?oDv7;$lge}o)1?=jnL3?j9yn03^c|L-?jKU#=Mjsaq~""ciAXkQ{>o+_YttQ2xyqUA78Vmkg0shDp6;Yq1Ku$x#fp&TCJLAjDgOo;y=Y|2xGxBaWetiODmKgIkDm+G-u~h@FD-pcIO(5j@Bev""w>EOt(SFXQ8I_4lt@%;oF&#D0+TUz}Nfb?q8QQ~Ur<;Ji^xLP*082}!Y7?5;T++L}wc-oGJMG$1u{-3-WEXDagir>=gDefit@GS""K1+5ozDX~oryl${x>Be0^&Vrs4fmO=aZF|""N4coCYP7sTRok_hcrN{}q$!_Tjkq^#Z__JN;Ju9$Gy+r6rFqO_=S45D_lUBY(_I+~(cCMx?Bbx3`eutdj12a-TiIv9P2XjCda1d9""ONdhX(zRe(T{{1nWSN=RN)E^i7Nc^<#EqOOK%o|_Lel5B6hY3OR(G`pUQZIOFNOoAmLa9jEX?3g$UGat^e}-A(vF3yn5`Kjhbye7""yJgLnmF;u(Z69Zl5r5#o2Mut?M%p_~N4wzgN$&O$8!B&9Uav-u62n88>Zpb|=?;oIzsZm&rL2N|DYB12tDjq);WiUu9Y2U>-sS6N""*87o=0j8VK!oUaB-6LoOW!$y!1{)&~|r@M+_S|R6tzIV&3hQ!REwLv1I__0oyvMfb3Fjk(m1a9)gb1CYrUcfu%rTfa7m5W|($o{UlNx=}hcc&O""!I8y?IY?OlVU2b6q=%AnC(dPoTWOR8g)zyc;+np8{L5-{sFNJ`b!Si4W5}k6w^?~U)<-?#P0z-8k}gR-$8-#""U>L0xS=JxSPsC{q&KkB0;=i{Eh!67Gdm2zvCw(ZgVxt22ry*DF)?wXSu8HQ&xp8I_GK9Z}>SM$rR_d!g80FT!H5pfGc%>pySL|ar""XPtVyAC@KlCKyZ_nGTGbv4h6WfFiUg6*uH7p0r6*W>YJCJHIoTVtO=brVWBX^W&<^m?C#}4{<@M-=-LxwcQw)lr(~?3*ZJL>j%E4""OaRcz8Tk<*w&H9HLiS~%VqkYVW<%W9cHmY=X#|M(!c?^^&auF7g;g$b=#8_CiJ9UBBkcG)v(j82}a={2#H`y4S}(8sPvy63P>Vp&UMV-FgeAfl1*L6vZgJp""@BV>PCthn|IweqpLkfn5rW-k|ia}gxglH@OJctg_x5jHI6cHgQ#)BKOu}AFhkeD7PN00Vta03CpQ`Nyq>0;^O-AD>MZsY^4eWysW""rv9&_=<56ln;o?j^z1IWQhq8$zno}7Fucp69()8h7e7d(Axe|*Hz?-V7%A!>SM6>ow3iJ=U-SlC;^mZf<3ffOuu(!J=a1fcx)n`K""JB|YT-Dor|J6wYD8h;u+V{nuB4|nM%v8lL3rL@iSln5pdUWoKT(oPK%S8=Fny`9#d+?LW{D0Xny9qndx=mjFCL%upkbt%I+|9NWe""-GlS!G;X|}(k?kvneqb73*-z1FSqs3E-ljfnr<#Vk1g@5US3gd7#II6(lbW7l40ECKb@rg($O>f&aN0Tq#FsymF)R_L33t%QLl#m""Op0?rd%s-cw5@;smu~HvVmX&%tvSGDVDgFY*`lLgjqz@W?JJb#mX3|i?r3(ATft785Y+z}pvt!pK1xQ9Y1B(>by`_RR(rzoiE9D1""d_g#^b*Y7LJ>Y~o#|wC%@UX}>1H2)O@*R9F(CurItl9+x}q^b;ftvw;fQ{M`{mQE}Jp?*sL*M04I-i4>$""tO$&Co9uHbLDaSOpkhd8kXT>Rufiz!{@B{>GKlu49j!xwkmcPl-Z|f8POZp{RFa!>NvLX~I-8A-BqX?G29P8%v1`!LcbRDMo9>BGuvKQIE+p4|k~2_dT7KD#o1zBg-{_ct+BEqze)ec}""^`s-e_}(Bi-aG^M#DGRCp5;TtMfu*qbFF7+{acH2YSk{(EVMle-WoFhS)c;B9D}7N%=c?+ER}!;xSdCrUq{iuOdn#=i?2`;#5KM^""8y=Zg;U0Z?d8?RqMphKh5GZHNL$(+p`OGY_Jge0a+m;hmPzVqtg~l~2)C=v#;l=APyFxMmu*JNOCdmpj_15_Mx@90Lj`)w;^|l>p""8RG_#(vyO%$#22p<%Giz^vz4F2Ks7TZ^r3%s5+A630AaRTD{#&cu)IHf|{u5{@)tClA{*39|NP$ATe4E5i)1zYA$nK4I""oWo
v!DG&l2CU=U!E""b!UeV&1qbq9l-D`e#&HDul(^Ja0kI5nV$20hbdJO=8K{v1|2(f(OqsEUH|aAk7s48Z0n4ENjNX0-y%2_JLHszP}@9bC9y=WA3RKL""V)qjAhj2^`;2&DAsqEWX$YxFd$gDOJ0PoPK0PC=V=wr4mcdI(nrwsJZ1^c%KPHh7zpt5~4{B|m-YYJmOCHUQHOCSU`>l9={VjvU6""-PwN(v0Z7nLc_}gv|6DX3`wOKa`9M;""`=C@3kmeRM^XaA^0?AE!fkikB5K90LcmgC;w!MTyfzp0*lL6DE>s*fo%I$D*1>$EBcdGR|gU4=kLGHE|OjqO7*;B1?NC+|S*D#3a""Z0H!*F0LG@b;s|ho@K~Tw7*rrA`0$&R^$1}T5XX)53PH@9P+l}fW+l}*y5V$)=XO-eXGsyU`-#WO-Wmm{C""sHPj=3O""FbY%RdyNKImi~b6k!T|i8E_13)C}ayW!^4Z6HZPhz5sq_>T{~{))6U>sf3KYjF*n5OHUi=-Xlpb?#R>CGi4l3j>j9QF6G94YoW{i9f0bVCRU@2CB2#`Yl37AOp*2uWoCH95;tQ-sCm8mSd&^O;IpWU#mSr`n^IYeGvPLG#6h7~4!Kyn+RsKDJ;_ZchyToH=JSVM+*QpG;%1udadY#B_QqwF&<|z!>{O_W?""0I}&=zxIzLrlZ6#7mi`H2IjaQb8^V#GT&3;_zOD@Uev4dKQ~owoT0s&=J(XLG`9ZDd?g)t&;4Y1YO7}B=K;P3NcTYU{|x7=={Yb|""pK_v;BDTH!rlz8@Q=v_a!_*k{^cU_+41D7IQ;7shVr#^IH(N}4U~{^%""(qLEOyjOkp>@2&?&OygQLq^Wz2af(lNw2`~$gE$X`1i`R%6)g4wb-w{-Y?ZW&d*ab4R%c=-tR{BKdPb?""9Q8tu8EvXv>JDx3xIPTe7c>LzEtgR#Stk!(e16|A$oYmZlvAh=81x=gAnbGx5Vb&g1>Iix%R^xV""pi&^W6-Uf~vwDd|e=Y`(opu@l#b~R7OK5C#(+JK8KN9nFf)tgoYGWA%nHju~0HxsV""XaC#4O1QNjHEA?*ys-n}bjh*G&2cfQPTgRIn5hSIwVu-CDs;NI4Pz5?+73$XA@bw1b#Yf4Tvx5_GRyy*Qp!=$6l4O6ha0dmU}5O^""RRBo~2t@`6;~PfT6=g3bt?{l>4LZqsRLSD!UH^7^aUdB8UYAVD0IV7sto+va{E#jo_6SPJ!}I;`B6>|HpwN@uwiWUtX3Q@CKW{O*""O)r}sqN!V?9~BMZztY5+HHFyCKMAx9%S)NTy{oj6zQI4T-7Qv;pgOT2%lwr9hk}o&zzuiy!|`sz50E%sGUdJPp|zlyrZpFM&&#Bv8)h~!cvE%A(lrXnU1^V>&>@5gBeKKjGHmc4y?8txzIR^b1v&f$""Zfa&9x<3IYZj>VAD}wSF_(3Vib`#~D+y+|fcM?D)e&Jju-OCp8nfh9ry0=ly@|(P3xV%6m1Lhxd+SA<;z|$r)^GPfD42`$mYO""^zfEK@Mi!*`u^mNO#rw^{$~7V0jh_M3`Wog;jhH-GP$EHfmz2SIR|x*BcO?J_uR|inUTXPJbM>QE*2vyZGg=ga8O`$evemCnEo#`""J8>6Bdo3DLK^VCnaPNi~q|Ig`%||Cu6A1X=GwhO1ZMs^@EAkMUz@O<-eVh1*ZGM|0==};cID`oc01acj7*3s^yUSNFfXv~da4dI`""=)=4jtP+c6zk><)oTS*k6t{1JVQrMGCk^jBCuft=hVt^342;r=q6bxagt!;!s+FL?uqT8ai5$(3x;H7r{l~J3%l|s4bGF({h86+w""*F0u06GlV>puIm-oEGvmOeh1>8)j(b1(o<>e!|s|I9?o3v0-Xq#I)xgaM^jKqBRdRr}IVCK!IkU#+g4SbKDF`R%}@Up2m%f4bi0i""77$4wvUlUWnDyVl}QfW>r`mFP9u`#W1%$zFnVN90|!gX|(4LeZ##4agXjF=tVFuj!$wSPlK>?>Di`;sD5)}}R~9CaqA""!X|5T6@+w+*QBEV5qT6yfrx|79%hBuB-#u%=$O09dcx}1r{%rs$tdKPVwa~hpaGWjl-){1^wI(TQz+Zm&1h~A?s^#;onn`9TqAD`""$2f*GeoY9SfOa7$t|+=R3YXiB?t>dk%#""YZt*J2=iugz7UX3mZAw`QI8vr$eUdc1I~Dt0?wW(t4f@(t0#+_zHZ815K#GDuM??NHe|I8$%RMS_L_Y8R$*Bz{TbrO&(>OiP&Mnl""S`8Sc4$cH%)_OlzQB%Ld&ofWEC7^Vi&xKtBQiGLLlmmVscL#OEv3V|AbG5In>PkP6-3K5D9@z;k4KM6T{OUQ;D#xFU%>ekfx%0@Un?m>WT_&fZa-3oGj`&<+aZeQ6?JFb$9e!090a)K1w;d-s2A%LxbzO`m(_|9G3!WNL*RiLPCq^ES^2ATU}lV?T{oWxF?zez=xsyi5Nu""A*l;(BD<6Rj2|@q-m7%wT1Z1eBq}8nW-#yDN)K(7rWvp78;XVTOr=-WAo77aA|I!ryi7)djRi^4t{M(P*3OX`hT;_zwdIM?e$=Ps""c9c+FY%d^@RzRR(wu4s&doh@8k%JpYDda7s#M*F${i8-}4Acf@n7yzGX58?nVLiCE0Sv?)I)WQ$-ml5~5bE$dmW{1kEv+5gIeRM@%dP?R(*^-b5n?8b0#uG3M`i@Cqb4i+XRhsEd2^3f*iuy0SJnUeK>8dyFw=Y1tL+PFf!M2E?MG3fSZGK8d6wNWdFgX0heb*Gdb&LwK2T3(A3>d(Ko@Eg@qqrXjY`!|w#c>S*>Q=62SHjf&&""7Ef4W!!uC&ab`$Idh83cVB}K*Z3qN5+>b|;e9-Lvj`9zT5{dI7S}b&AYL`o=X3?-SbD*`kVd(|<(IyKh7u5R=T4TdLo`LgCw7l(2""+7E0f3_teoKEq#*RxgEQm>t`?`_i5lP&Hr{u3r{T7rO|!bzXZNL<9hBj&q05rWR@e)!SJTa{wiZg24bsXIU{i!C""?pAvJr;LpsEk4eh*eWC-vrbNIxg?Oz_fE&Xr-g&C&@NAF$QF%9pqoY%agi>D{?58w@VQGRfGJ#UG2b1QvT@o@o2Ji)Taq;ARymyl""ZnM8wntUwcJkR`aTChVX)wgbiE^I@@+%JqR;-m0-ibprFW`3P*GWB*I%Aj_0CQpmTmGNiUVj&qYQ|pEVS#E$WTFc^ND2G4>yi&V?""PMJ3sGLKOa|7hVJ)*CzO@m**l+LX(yBlvTL*(^Z;6Qa37BVTW`T5P7TMIBc+mGsarM%UZ{1eXB{fyD7fV;|)j7+`c5qyU""BQ7nPJK_92*hs5)@!5D$Kb<9d3_U=Vr8>-%E)*i}BQZCC3=(Y86Zr*t3Zv#-Jx?nPSqCEEe+#BO0FNU?6m7RI8YpP6M|?bH5{*H*""eOO3YM|U0!_mF{MLF?IDo2gC+P=yB=&sHtMxgUq7JhBBQPMa3^52?X#@g~1pP3eae1QdV#m>KX(&YMXPTL^$UzM70|m7q0%`2Q?V""<{7hRXjHpV!3}bLGqV*D#g8_ls03Ebzn${9P4k!_B_+?Qu2@TkD=XtvvPC
J4zDrkRz+jjWTAt7p2_Hj{m&kGIPv|b>w*;JSMsC%T""67VsC$S(_9jSt;u^zlK)3_or+t0BB>Q+9p{*4~YVM*Jim|d4>Y76-aRH""HyK5VS|a--P@LR3DrBh%0v$ou*)-Gyp~2Oe>jDoV^Zez6K57@sswg~gG8waI)YdhLShK)C1wdOO4|ervo*@<&BC==9r<7A~wciDU""AQ3}of6&dh!XMV*X!J(%8kUUo7}9Dt+sX>TFd#2UUWC%Hy!ZJ@8y{Zk>UtBd4ieXNF`wvyQl;{?*NyY@ya=uuy{jurFG6#)=@(0q""-^^8+VQyfSt}U}Cx^X1UP1K_Dvu0I=&i_RS""&FGgywG%`^!Mm&;g|HVEwjOZsIz+hd1$9>B%gEFl-ijEp2dfi=#mVYabqd=s_-bHfBz2vR848h>NlbFbAu+wUzG&lIhZ?IdRL%jN""9myCj{r9E){rpWkwa(^9Zq}pOJW#uu@_7c)bVv%tM8g=k!zME+N1-dr7RROa&?TDU5Jd5r)pyn#6*$lT9GQ_>M48NOPDZp<{`c)fD)`N+7g9d9POITzsZW$^$>O6lda*a46=vjY3YE6Os_Ko_H4u""PB+d>!8~*!la|^565vA11=e+)JNN-$W>9B6VN?bnx%D&4$7f?+)QQ1l*|exfu~yUsFQZZPjsyOEnKxp)a1p@z3O4K=V$PSj*9y>_""0YvC#ohTKw!sNAf30wq_Phu%6jNU_Fnh?2v|1>hot3_QzsrbkD#Ax9;S+lk+OA)x7%5T09!PRjYx1)2@d<5+LTC-*B8hL!gCxXF{_?MzEa>0I?#oKjiHSEc0>8-E)xHltER?DRA4@i6oR@`g?-m6*-S{&$&#IK}wR*Q+jDSTX9{X_4GA`#ye""!paFtf-_Vm;FpB5R>%VxYjPi$mN!O@e=U5COw}""g2?!pV>-Ju+K%`q4cghVU602^47!fv8)t`CJn)UygZ5Y;-fNZ7<2C0fbXT`hVKBNMCNG5jLd==H8Y7q&#vV}kqPP$+aD~Zs*~Hqb""HuGa~blJLX$Prk|%-{-x+_dZ$XiO`OMhbP|RO!;Q;YQ>O$TGpl%mvw4Ts+*54q7AhNN0?@E3<#ZJ1c{@W%@K!`w}#T5oS""RN?AgM)WFF_q7Z5WZg6b_vcO@9gXwM>qRLK15QSw{oXE7Zmov>oi(v2GuO@Vg7L=#(FaFb$gdCj!1""*4ZnRno;NutmEqIZud9a;5OEObKp2@MCGOA#6=u%iivuJ^^FCgr?Tr_M{phavEtq{*z@z}SiJZ@wSk+~;~g8jBR-`9gydu0=iEVK""&XV$~Oz>VOu0_BoJKJbwv%Ei84~`o#pS{?P?F@UV{ePF-c^>i$jNT#F@}zJ0!xGYBKQs`<-Jmoe{|lSGC2t)psWkS9m2%alcBBbt""Fd#a7WRX&dDeUceTGpLR-Bkl{_2eJ<8QUcJe`VQOvZX_8RT)m$WWl0SZ-*nB_lc=EXFg>b`SQ;z7ICJ?^FZK~IM`9pq!0aCMZpLXvX""C8WayVD#vcpoZP--A2#r&MS#m(Z{UdW+ox8N2qgr1UHYXl9B_4Gmj&~bmV{>L)zvP`HB#""ynzRo&_&yKGVjhTp|B-6o!VMt#6LUxw#K$Z*QwR@4uio_Via$Rjd2*+I1g_W$BaO~s1%q-135wI`bOe&m$E6XlK2*Fs*0T=kF!bn6RGK(eM%J7Nid>J6vtA)GDu0dU&LS0YE#ozS!0TaD53V9g5tli""t*L^o)vb~l!oT;$O|u0)!qp?^-DEO)N@^Q%;ftKJ)p)8G3T)Mv%Qt&!}gP+1MOnRPy7am>{Gd}t_dT7B61>%UxFt-1W*Bs;;qLT9nm0!MUACC@W2m*""@p~#4i7S|0FzOIL6QCbUn5)S1yRQLrUwO?@3*)8#$hA+Yky5+YjF_+K""o5S4ln?wr=H&QDIv&p#90U~9k0V!lOIfV-YTYoMoa;_l@thWHC8K|h&y>3zx)nt""sOyx5dSlJ|2mYqJFW_8+)%HL_zCpNp;WvKv5EC74U2hf{6(*BLOt-*e3%__I3rs(L35WGqcGbpw<1VCck$GBRR<&V%cy1usfFpBl""FJZrF#_mDkxS-co(sLkNIA!*Jy?)wT+3dY+exDD1y^j_dVAktMo1$W6X+p9K2Jj`&-arE(8b_3KbB@n*pj*Yu;a?p94#6VB4Bm+zJ+W$UEX693oc%$yqQaHAEB8H2+Fv?UC6~lXlF>UYp""*c5yJo1b37CBFwejFsT^;){7shwY)iX$TTvGOkURJWQ~+R`$0MU5#lO41gGth=ps!a4W#;axKy@kgNDX43U4iSe`mo+m*_vF>-_f""@wDaeaZ+uRde*GtF$-tKD!kYB=W4R4P&|t{0z@{U98VdxmFvvS{aeJ`III|bZZ2>&Qc`xN#q=65hI`99T5RH~a;MCb64m%nk%LRB""iB|4b5ZcjUew_Q2K*?xeXM600TLI6Zq(_3`-MW6rwx7Okf7eG!r{mY9OS}Pc_7V7sheM{WTVI&}{;pV-(TI77jua?wy(>$)kT7m@EjFpGoine5lh^Dq)vFQTEPX""*!#@B5H}Ed`7@__6wqs0mTNDd8XJVOlBk{4O@=_&mJ)#vhC5>aJLa9_GvinXc3GoFo)DcZ*ug-L?M5+a{SiX58PY~Cgwe7TYLf1d""vgDi<98izug;|-)!k%bN8XFSnsS!pc^6JG;V%zj^qp$Yx_B*Y^7""oc9d9(7Q)D88m<$7VwCJ_zifks(neSBXL#M!e};US1tiRey#=pv4P;Amh8YO!!x!fCOdcpyre*JBV9@E=GEK@eB%=aZh}8Z#1yc)pz0r9bw%@}`VM!;6UbtwegSrsk%C62wPn+G=~9&(xP)*5INPTo#>6T8_cFj^OMnY)fekxGvk*f7_Ps""YZz)#G3B0Xnj)-{xx^cz3$Fw8sq`ED=XfE>-?Kz72NZxbc3>f|m@r0yPaLgp4?h=a4cNOGj=l5GNPB0@%mrepp2MWhOu8C@c`5R8""0w67*pd&(Z=s8a039&$sjaC2p?9>}o##|3rYtW@~6Lw8dko@4x)l61Re-%ZBt?)<9>-w%>7nyK#wMFpK1)kNMUo3tD?;DV!N!M_P""iaMRE=}4q|v!t}sWTVH0OVkX==eG;Vu52*+W*<30$bi(8lKIfe#{VGExr47Pfpx>ar>s#G-Zc8g(Bg*hNXa7*EtnX%x_ybIcU&DH""P9k%-W^lOGXU{u4V86LE5nb$>0n}6&0L#jP0-kVw%G(_c@ls4gpDUyW(^ak4nw%uqa6+^fL6%mZJyc~#-""Hy~Fb+M1@6oWaJx?tvjXO;mE-UW!6Pnx=KCWPKY|Ry{m|AK=7HoTu>%D=S!4q762yPl3R#g_{E!j4VT!7&1t&5v@h7f?nXDA6e8|""l3EwZE~r|zk|j;oz?mXI(q2-|J%xW^@wU=TRs@mxb~Aj?crfWf-+ez&%?Hm32zuX2diR?P=9E`)1%=pzqeZLmt4#qrJsR47pDW*F&lSKR9tJ7t*>obpDt+K!bMYv_MkdZ$_8^7tg<-A~y=k^UCNE`zmk(p3vNC$R(jmkd4dXqID=Y%V9Rm""Kn(#ND*52^@
3_fIe)cizSy#q5ZCkuaN^ktfm6_O%&Y&GEDEcp?{r^<=`qs}L""@QpG}XR@tQ7~mYtZg9G-uDws336+^jtit3purZF#31?eh3KjV<9grb)=2s;Pm{*RZnReloVf!iC~OyrDyNP;v=e""x~;_z!%yLZ7nRu^2jZ>)6Mm~q^xywkh<#KYjBJG&j&iA!&2?ZMbyHiP$FqwV3@gi#V=QBsULj&ovm(ztD5Usg5M0~x""KG2{lwfFH2GT&3@d;t~58Bw(hU|o-)JPIUFaKds4HK8Uptm4Ja+&Eg7KB?1$8_5%zS4vP1X>WY+iMi+X5t6rzq`G|7!80INP?|a$pi#eAijwTOyx3i)`Z8Gxjk8WBaq7UuWMW*GN`u^q93M@S!aI_iS3KcPkfNR1kdmAsibwCAwMgT!igIOnU!3Jo;KK;_A$L}~R{;y7zJ{v#sA4>5y!M;swaxE>""ezFRrONt2sZ9<3ki+5%mwX|5(xBLio&i(^?Yek0$L_QvIFE`Xy+mm%DPZ&^<)@-n#BeLG`f8q<_SgQ@v4p#)`V3Cr-(SfpQ#{#4(""o{~{%MEf8`;0TJ?ou!_GOw@&Pp8}T>Li#Gs$Kvb9a(4?ZzndDAQ|v#rx97-QYP&UeWW)3g_wwRa76T8;8`+8U6@2)n@n0lF@QSB<""TVk80o~d@Lu|rs2QAo(Ccr<(D!M?5@`O?dDzF7Gt@TZb#>cbu;>N{R~w4W=9OwuBOpLE}riDOS%hrof{YKkg^l1@#oF#!V6lGpFW""7=azVWjYfaUBm0x6KzY#;e@<|6h#Y^ZrJ4X@#6^hIO4T7^3pO(OoY24)Njkv+%%WW^J&PhV?%{7C(rB+gAshvAh@PU68wUsA;WN|""KCgKgBU1^==EPtXMwbK6sV%^gJ5^VGBChL%zKwX}HeSYRJQ1Tq~8gblvMVPN^7Ve9m8q2Tx3|mlVa6qHV8(x})cmlDF^@=R7O-ur-L>rUWh7dS|B~H<9F5T7L""j1`ao(BU@lwK)gy?>e)ED+UWfnci%aadw12WvZzC#fMuq_jm@Sw{lZ$^w+$&e)wjL63uuGVfS#TPf+_cQzI*eBI1r1Zpi4Llua3pYNy4vtvlFU~!%wJ&a7L!R1NC00000^F}))5Jte=00GIB0icTo%e;*xvBYQl0ssI200dcD"))) \ No newline at end of file diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log index 612139e779..e018372e87 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed314.log @@ -1,7 +1,7 @@ -W0403 11:59:07.019000 46017 torch/distributed/run.py:803] -W0403 11:59:07.019000 46017 torch/distributed/run.py:803] ***************************************** -W0403 11:59:07.019000 46017 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0403 11:59:07.019000 46017 torch/distributed/run.py:803] ***************************************** +W0403 13:37:00.984000 66519 torch/distributed/run.py:803] +W0403 13:37:00.984000 66519 torch/distributed/run.py:803] ***************************************** +W0403 13:37:00.984000 66519 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0403 13:37:00.984000 66519 torch/distributed/run.py:803] ***************************************** Hyperparameters: adam_eps: 1e-08 adam_wd: 0.02 @@ -27,7 +27,7 @@ Hyperparameters: iterations: 20000 ln_scale: True local_rank: 0 - logfile: logs/56d7b60f-bc99-4d4e-b53e-cc83e5e5a1a1.txt + logfile: logs/5483c94d-1b8c-4d40-84c3-4fc28c21f182.txt logit_softcap: 30.0 matrix_lr: 0.02 max_wallclock_seconds: 600.0 @@ -44,6 +44,7 @@ Hyperparameters: num_heads: 8 num_kv_heads: 4 num_layers: 11 + parallel_start_layer: 7 qk_gain_init: 4.0 quantized_model_path: final_model.int6.ptz rank: 0 @@ -52,7 +53,7 @@ Hyperparameters: rope_base: 10000.0 rope_dims: 16 rope_train_seq_len: 2048 - run_id: 56d7b60f-bc99-4d4e-b53e-cc83e5e5a1a1 + run_id: 5483c94d-1b8c-4d40-84c3-4fc28c21f182 scalar_lr: 0.02 seed: 314 skip_gates_enabled: True @@ -86,7 +87,7 @@ Hyperparameters: xsa_last_n: 11 train_shards: 80 val_tokens: 45508608 -model_params:34401371 +model_params:34401372 gptq:reserving 10s, effective=590000ms warmup_step: 1/20 warmup_step: 2/20 @@ -97,36 +98,36 @@ warmup_step: 6/20 warmup_step: 10/20 warmup_step: 20/20 0/20000 val_loss: 8.3172 val_bpb: 3.6146 -1/20000 train_loss: 8.3192 train_time: 0.0m tok/s: 8448432 -2/20000 train_loss: 12.3839 train_time: 0.0m tok/s: 8340059 -3/20000 train_loss: 10.8345 train_time: 0.0m tok/s: 8240442 -4/20000 train_loss: 8.9588 train_time: 0.0m tok/s: 8184109 -5/20000 train_loss: 7.7775 train_time: 0.0m tok/s: 8147155 -500/20000 train_loss: 2.9100 train_time: 0.8m tok/s: 7873796 -1000/20000 train_loss: 2.8944 train_time: 1.7m tok/s: 7869012 -1500/20000 train_loss: 2.9170 train_time: 2.5m tok/s: 7868390 -2000/20000 train_loss: 2.6561 train_time: 3.3m tok/s: 7868298 -2500/20000 train_loss: 2.7093 train_time: 4.2m tok/s: 7869275 -3000/20000 train_loss: 2.7589 train_time: 5.0m tok/s: 7868349 +1/20000 train_loss: 8.3192 train_time: 0.0m tok/s: 8486260 +2/20000 train_loss: 12.1910 train_time: 0.0m tok/s: 8362707 +3/20000 train_loss: 10.6802 train_time: 0.0m tok/s: 8273940 +4/20000 train_loss: 8.8368 train_time: 0.0m tok/s: 8215311 +5/20000 train_loss: 7.6682 train_time: 0.0m tok/s: 8181050 +500/20000 train_loss: 2.8965 train_time: 0.8m tok/s: 7977828 +1000/20000 train_loss: 2.8874 train_time: 1.7m tok/s: 7942881 +1500/20000 train_loss: 2.9111 train_time: 2.5m tok/s: 7923938 +2000/20000 train_loss: 2.6539 train_time: 3.3m tok/s: 7921598 +2500/20000 train_loss: 2.7126 train_time: 4.1m tok/s: 7921641 +3000/20000 train_loss: 2.7595 train_time: 5.0m tok/s: 7923120 recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10] -3500/20000 train_loss: 2.6858 train_time: 6.1m tok/s: 7467443 -4000/20000 train_loss: 2.6180 train_time: 7.1m tok/s: 7380198 -4000/20000 val_loss: 2.6437 val_bpb: 1.1489 -4500/20000 train_loss: 2.5736 train_time: 8.1m tok/s: 7313799 -5000/20000 train_loss: 2.5146 train_time: 9.0m tok/s: 7261199 -5421/20000 val_loss: 2.5309 val_bpb: 1.0999 -stopping_early: wallclock_cap train_time: 590042ms step: 5421/20000 -peak memory allocated: 30119 MiB reserved: 30156 MiB +3500/20000 train_loss: 2.6858 train_time: 6.1m tok/s: 7518482 +4000/20000 train_loss: 2.6133 train_time: 7.1m tok/s: 7426482 +4000/20000 val_loss: 2.6401 val_bpb: 1.1473 +4500/20000 train_loss: 2.5739 train_time: 8.0m tok/s: 7357193 +5000/20000 train_loss: 2.5142 train_time: 9.0m tok/s: 7302182 +5449/20000 val_loss: 2.5253 val_bpb: 1.0975 +stopping_early: wallclock_cap train_time: 590081ms step: 5449/20000 +peak memory allocated: 30120 MiB reserved: 30154 MiB ema:applying EMA weights 
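Editorial note: the `recurrence:activated` line above shows the exact virtual-layer schedule — layers 4 and 5 are re-run once, in place, turning 11 physical blocks into 13 virtual ones with no extra parameters. A minimal sketch of how such a schedule could be built and applied; the names (`build_virtual_layers`, `blocks`) are illustrative, since the shipped implementation lives only in the compressed train_gpt.py blob above and may differ:

```python
def build_virtual_layers(num_layers, recur_layers=(4, 5)):
    """Repeat the recurrent span once, directly after its first pass."""
    schedule = []
    for i in range(num_layers):
        schedule.append(i)
        if i == max(recur_layers):
            schedule.extend(recur_layers)   # same physical weights, second pass
    return schedule

# Reproduces the schedule printed in the log line above:
assert build_virtual_layers(11) == [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]

def forward_with_recurrence(blocks, x, step, recur_start_step=3000):
    """Before step 3000 the stack runs once; afterwards the virtual order is used."""
    order = (build_virtual_layers(len(blocks)) if step >= recur_start_step
             else range(len(blocks)))
    for i in order:
        x = blocks[i](x)
    return x
```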
-pre-quantization post-ema val_loss:2.52840565 val_bpb:1.09881597 eval_time:2007ms -Serialized model: 132405827 bytes -Code size: 23948 bytes +pre-quantization post-ema val_loss:2.52284860 val_bpb:1.09640094 eval_time:2005ms +Serialized model: 132406149 bytes +Code size: 24522 bytes GPTQ:collecting Hessians from calibration data... -GPTQ:collected 66 Hessians in 9.7s +GPTQ:collected 66 Hessians in 9.8s GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search -selective_prune: unpruned=15.96MB target=16.0MB -selective_prune: already fits, no pruning needed -Serialized model int6+brotli: 15939825 bytes -Total submission size int6+brotli: 15963773 bytes -final_int6_roundtrip val_loss:2.55451375 val_bpb:1.11016225 eval_time:7465ms -final_int6_sliding_window val_loss:2.51193797 val_bpb:1.09165930 eval_time:75591ms +selective_prune: unpruned=16.00MB target=16.0MB +selective_prune: pruning 18192/9387445 lowest-error ±1 values (excess=2274B) +Serialized model int6+brotli: 15972796 bytes +Total submission size int6+brotli: 15997318 bytes +final_int6_roundtrip val_loss:2.54951461 val_bpb:1.10798968 eval_time:7834ms +final_int6_sliding_window val_loss:2.50674417 val_bpb:1.08940214 eval_time:76523ms diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log index 105e99dc7f..b4413ec581 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed42.log @@ -1,7 +1,7 @@ -W0403 11:41:00.268000 3338 torch/distributed/run.py:803] -W0403 11:41:00.268000 3338 torch/distributed/run.py:803] ***************************************** -W0403 11:41:00.268000 3338 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0403 11:41:00.268000 3338 torch/distributed/run.py:803] ***************************************** +W0403 13:06:16.394000 48083 torch/distributed/run.py:803] +W0403 13:06:16.394000 48083 torch/distributed/run.py:803] ***************************************** +W0403 13:06:16.394000 48083 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
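Editorial note on the `selective_prune` lines: across all three seed logs the pruned count is exactly 8 values per byte of excess (18192/2274, 213768/26721, 70144/8768), i.e. the code appears to budget 0.125 compressed bytes per ±1 code it zeroes. A minimal sketch of that step under those assumptions; `q` and `scales` are illustrative names, and the shipped logic is inside the compressed blob:

```python
import numpy as np

def selective_prune(q, scales, excess_bytes, bytes_per_value=0.125):
    """q: int6 codes for one tensor; scales: per-element quant scale
    (pre-broadcast to q's shape). Zero the ±1 codes whose removal adds the
    least squared reconstruction error until the excess is recovered."""
    n_prune = int(np.ceil(excess_bytes / bytes_per_value))
    flat, s = q.reshape(-1).copy(), np.abs(scales).reshape(-1)
    cand = np.flatnonzero(np.abs(flat) == 1)    # only ±1 codes are candidates
    kill = cand[np.argsort(s[cand])[:n_prune]]  # lowest |scale| = lowest error
    flat[kill] = 0
    return flat.reshape(q.shape)
```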
+W0403 13:06:16.394000 48083 torch/distributed/run.py:803] ***************************************** Hyperparameters: adam_eps: 1e-08 adam_wd: 0.02 @@ -27,7 +27,7 @@ Hyperparameters: iterations: 20000 ln_scale: True local_rank: 0 - logfile: logs/6e28f26e-23bc-4ec7-832a-8bee511d812d.txt + logfile: logs/e4af8e32-7075-4034-bea3-7d51d6a0ce79.txt logit_softcap: 30.0 matrix_lr: 0.02 max_wallclock_seconds: 600.0 @@ -44,6 +44,7 @@ Hyperparameters: num_heads: 8 num_kv_heads: 4 num_layers: 11 + parallel_start_layer: 7 qk_gain_init: 4.0 quantized_model_path: final_model.int6.ptz rank: 0 @@ -52,7 +53,7 @@ Hyperparameters: rope_base: 10000.0 rope_dims: 16 rope_train_seq_len: 2048 - run_id: 6e28f26e-23bc-4ec7-832a-8bee511d812d + run_id: e4af8e32-7075-4034-bea3-7d51d6a0ce79 scalar_lr: 0.02 seed: 42 skip_gates_enabled: True @@ -67,7 +68,7 @@ Hyperparameters: train_seq_len: 2048 ttt_batch_seqs: 32 ttt_chunk_tokens: 32768 - ttt_enabled: False + ttt_enabled: True ttt_epochs: 3 ttt_freeze_blocks: 0 ttt_grad_clip: 1.0 @@ -86,7 +87,7 @@ Hyperparameters: xsa_last_n: 11 train_shards: 80 val_tokens: 45508608 -model_params:34401371 +model_params:34401372 gptq:reserving 10s, effective=590000ms warmup_step: 1/20 warmup_step: 2/20 @@ -97,36 +98,180 @@ warmup_step: 6/20 warmup_step: 10/20 warmup_step: 20/20 0/20000 val_loss: 8.3187 val_bpb: 3.6152 -1/20000 train_loss: 8.3201 train_time: 0.0m tok/s: 8389972 -2/20000 train_loss: 12.3377 train_time: 0.0m tok/s: 8284614 -3/20000 train_loss: 10.8503 train_time: 0.0m tok/s: 8203994 -4/20000 train_loss: 9.0314 train_time: 0.0m tok/s: 8146658 -5/20000 train_loss: 7.8217 train_time: 0.0m tok/s: 8099434 -500/20000 train_loss: 2.9053 train_time: 0.8m tok/s: 7916776 -1000/20000 train_loss: 2.8899 train_time: 1.7m tok/s: 7893412 -1500/20000 train_loss: 2.9120 train_time: 2.5m tok/s: 7883558 -2000/20000 train_loss: 2.6571 train_time: 3.3m tok/s: 7880543 -2500/20000 train_loss: 2.7133 train_time: 4.2m tok/s: 7880237 -3000/20000 train_loss: 2.7634 train_time: 5.0m tok/s: 7879763 +1/20000 train_loss: 8.3201 train_time: 0.0m tok/s: 8418023 +2/20000 train_loss: 12.1403 train_time: 0.0m tok/s: 8364587 +3/20000 train_loss: 10.6837 train_time: 0.0m tok/s: 8270566 +4/20000 train_loss: 8.8966 train_time: 0.0m tok/s: 8220574 +5/20000 train_loss: 7.6979 train_time: 0.0m tok/s: 8190488 +500/20000 train_loss: 2.8967 train_time: 0.8m tok/s: 7974695 +1000/20000 train_loss: 2.8849 train_time: 1.7m tok/s: 7943029 +1500/20000 train_loss: 2.9119 train_time: 2.5m tok/s: 7927554 +2000/20000 train_loss: 2.6524 train_time: 3.3m tok/s: 7922710 +2500/20000 train_loss: 2.7072 train_time: 4.1m tok/s: 7922283 +3000/20000 train_loss: 2.7568 train_time: 5.0m tok/s: 7922154 recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10] -3500/20000 train_loss: 2.6812 train_time: 6.5m tok/s: 7104286 -4000/20000 train_loss: 2.6112 train_time: 7.4m tok/s: 7066158 -4000/20000 val_loss: 2.6367 val_bpb: 1.1459 -4500/20000 train_loss: 2.5649 train_time: 8.4m tok/s: 7039143 -5000/20000 train_loss: 2.5059 train_time: 9.3m tok/s: 7016181 -5257/20000 val_loss: 2.5348 val_bpb: 1.1016 -stopping_early: wallclock_cap train_time: 590059ms step: 5257/20000 -peak memory allocated: 30168 MiB reserved: 30220 MiB +3500/20000 train_loss: 2.6758 train_time: 6.4m tok/s: 7142382 +4000/20000 train_loss: 2.6040 train_time: 7.4m tok/s: 7102304 +4000/20000 val_loss: 2.6325 val_bpb: 1.1441 +4500/20000 train_loss: 2.5606 train_time: 8.3m tok/s: 7073621 +5000/20000 train_loss: 2.4990 train_time: 9.3m tok/s: 7047648 
+5279/20000 val_loss: 2.5291 val_bpb: 1.0991 +stopping_early: wallclock_cap train_time: 590074ms step: 5279/20000 +peak memory allocated: 30115 MiB reserved: 30154 MiB ema:applying EMA weights -pre-quantization post-ema val_loss:2.53240416 val_bpb:1.10055367 eval_time:2008ms -Serialized model: 132405827 bytes -Code size: 23948 bytes +pre-quantization post-ema val_loss:2.52678840 val_bpb:1.09811313 eval_time:2013ms +Serialized model: 132406149 bytes +Code size: 24522 bytes GPTQ:collecting Hessians from calibration data... -GPTQ:collected 66 Hessians in 9.8s +GPTQ:collected 66 Hessians in 9.7s GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search -selective_prune: unpruned=16.00MB target=16.0MB -selective_prune: already fits, no pruning needed -Serialized model int6+brotli: 15975217 bytes -Total submission size int6+brotli: 15999165 bytes -final_int6_roundtrip val_loss:2.55828836 val_bpb:1.11180265 eval_time:21824ms -final_int6_sliding_window val_loss:2.51621374 val_bpb:1.09351750 eval_time:98539ms +selective_prune: unpruned=16.03MB target=16.0MB +selective_prune: pruning 213768/9347641 lowest-error ±1 values (excess=26721B) +Serialized model int6+brotli: 15941406 bytes +Total submission size int6+brotli: 15965928 bytes +final_int6_roundtrip val_loss:2.55624518 val_bpb:1.11091471 eval_time:22212ms +final_int6_sliding_window val_loss:2.51342996 val_bpb:1.09230771 eval_time:99339ms +ttt_sliding:start chunks=1389 chunk_tokens=32768 total_windows=711072 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 +ttt_sliding:params unfrozen=34401372 frozen=0 + ttt_chunk [1/1389] bpb=1.769230 time=0.5s + ttt_chunk [11/1389] bpb=1.721496 time=3.4s + ttt_chunk [21/1389] bpb=1.603810 time=6.3s + ttt_chunk [31/1389] bpb=1.582376 time=9.2s + ttt_chunk [41/1389] bpb=1.566326 time=12.0s + ttt_chunk [51/1389] bpb=1.544675 time=15.0s + ttt_chunk [61/1389] bpb=1.526853 time=17.9s + ttt_chunk [71/1389] bpb=1.513607 time=20.8s + ttt_chunk [81/1389] bpb=1.519019 time=23.7s + ttt_chunk [91/1389] bpb=1.517489 time=26.6s + ttt_chunk [101/1389] bpb=1.507338 time=29.4s + ttt_chunk [111/1389] bpb=1.496461 time=32.3s + ttt_chunk [121/1389] bpb=1.496869 time=35.2s + ttt_chunk [131/1389] bpb=1.492461 time=38.1s + ttt_chunk [141/1389] bpb=1.491439 time=41.1s + ttt_chunk [151/1389] bpb=1.489377 time=44.0s + ttt_chunk [161/1389] bpb=1.484802 time=46.8s + ttt_chunk [171/1389] bpb=1.483128 time=49.7s + ttt_chunk [181/1389] bpb=1.478856 time=52.6s + ttt_chunk [191/1389] bpb=1.477131 time=55.5s + ttt_chunk [201/1389] bpb=1.472608 time=58.4s + ttt_chunk [211/1389] bpb=1.471204 time=61.2s + ttt_chunk [221/1389] bpb=1.468757 time=64.1s + ttt_chunk [231/1389] bpb=1.465753 time=67.0s + ttt_chunk [241/1389] bpb=1.462552 time=70.0s + ttt_chunk [251/1389] bpb=1.460609 time=72.9s + ttt_chunk [261/1389] bpb=1.459249 time=75.8s + ttt_chunk [271/1389] bpb=1.460159 time=78.7s + ttt_chunk [281/1389] bpb=1.457581 time=81.6s + ttt_chunk [291/1389] bpb=1.456259 time=84.4s + ttt_chunk [301/1389] bpb=1.451895 time=87.3s + ttt_chunk [311/1389] bpb=1.448422 time=90.2s + ttt_chunk [321/1389] bpb=1.446041 time=93.1s + ttt_chunk [331/1389] bpb=1.443387 time=96.0s + ttt_chunk [341/1389] bpb=1.440579 time=98.9s + ttt_chunk [351/1389] bpb=1.434912 time=101.8s + ttt_chunk [361/1389] bpb=1.433511 time=104.6s + ttt_chunk [371/1389] bpb=1.434669 time=107.5s + ttt_chunk [381/1389] bpb=1.429158 time=110.4s + ttt_chunk [391/1389] bpb=1.429244 time=113.3s + ttt_chunk [401/1389] bpb=1.426374 time=116.2s + ttt_chunk [411/1389] bpb=1.422537 time=119.1s 
+ ttt_chunk [421/1389] bpb=1.417841 time=122.0s + ttt_chunk [431/1389] bpb=1.415386 time=124.9s + ttt_chunk [441/1389] bpb=1.413801 time=127.8s + ttt_chunk [451/1389] bpb=1.411722 time=130.7s + ttt_chunk [461/1389] bpb=1.409014 time=133.6s + ttt_chunk [471/1389] bpb=1.407085 time=136.5s + ttt_chunk [481/1389] bpb=1.406016 time=139.4s + ttt_chunk [491/1389] bpb=1.404490 time=142.3s + ttt_chunk [501/1389] bpb=1.403178 time=145.1s + ttt_chunk [511/1389] bpb=1.401990 time=148.0s + ttt_chunk [521/1389] bpb=1.400406 time=151.0s + ttt_chunk [531/1389] bpb=1.398222 time=153.8s + ttt_chunk [541/1389] bpb=1.397203 time=156.7s + ttt_chunk [551/1389] bpb=1.396796 time=159.6s + ttt_chunk [561/1389] bpb=1.395993 time=162.5s + ttt_chunk [571/1389] bpb=1.394271 time=165.4s + ttt_chunk [581/1389] bpb=1.393615 time=168.3s + ttt_chunk [591/1389] bpb=1.392006 time=171.2s + ttt_chunk [601/1389] bpb=1.391217 time=174.0s + ttt_chunk [611/1389] bpb=1.390306 time=177.0s + ttt_chunk [621/1389] bpb=1.389059 time=179.8s + ttt_chunk [631/1389] bpb=1.388138 time=182.7s + ttt_chunk [641/1389] bpb=1.387501 time=185.6s + ttt_chunk [651/1389] bpb=1.387610 time=188.5s + ttt_chunk [661/1389] bpb=1.386668 time=191.4s + ttt_chunk [671/1389] bpb=1.384669 time=194.3s + ttt_chunk [681/1389] bpb=1.383328 time=197.2s + ttt_chunk [691/1389] bpb=1.382753 time=200.1s + ttt_chunk [701/1389] bpb=1.382204 time=203.0s + ttt_chunk [711/1389] bpb=1.381860 time=205.9s + ttt_chunk [721/1389] bpb=1.380907 time=208.8s + ttt_chunk [731/1389] bpb=1.379835 time=211.7s + ttt_chunk [741/1389] bpb=1.379834 time=214.6s + ttt_chunk [751/1389] bpb=1.378762 time=217.5s + ttt_chunk [761/1389] bpb=1.378263 time=220.4s + ttt_chunk [771/1389] bpb=1.377287 time=223.3s + ttt_chunk [781/1389] bpb=1.376448 time=226.1s + ttt_chunk [791/1389] bpb=1.375445 time=229.0s + ttt_chunk [801/1389] bpb=1.373984 time=231.9s + ttt_chunk [811/1389] bpb=1.373237 time=234.8s + ttt_chunk [821/1389] bpb=1.372595 time=237.7s + ttt_chunk [831/1389] bpb=1.371474 time=240.6s + ttt_chunk [841/1389] bpb=1.369467 time=243.5s + ttt_chunk [851/1389] bpb=1.369110 time=246.4s + ttt_chunk [861/1389] bpb=1.368256 time=249.3s + ttt_chunk [871/1389] bpb=1.367545 time=252.2s + ttt_chunk [881/1389] bpb=1.367260 time=255.1s + ttt_chunk [891/1389] bpb=1.366111 time=258.0s + ttt_chunk [901/1389] bpb=1.364897 time=260.9s + ttt_chunk [911/1389] bpb=1.363506 time=263.8s + ttt_chunk [921/1389] bpb=1.362164 time=266.7s + ttt_chunk [931/1389] bpb=1.360782 time=269.6s + ttt_chunk [941/1389] bpb=1.359644 time=272.5s + ttt_chunk [951/1389] bpb=1.358449 time=275.4s + ttt_chunk [961/1389] bpb=1.357255 time=278.3s + ttt_chunk [971/1389] bpb=1.356356 time=281.1s + ttt_chunk [981/1389] bpb=1.354755 time=284.0s + ttt_chunk [991/1389] bpb=1.354438 time=286.9s + ttt_chunk [1001/1389] bpb=1.354008 time=289.8s + ttt_chunk [1011/1389] bpb=1.353753 time=292.7s + ttt_chunk [1021/1389] bpb=1.352840 time=295.6s + ttt_chunk [1031/1389] bpb=1.352064 time=298.5s + ttt_chunk [1041/1389] bpb=1.351316 time=301.4s + ttt_chunk [1051/1389] bpb=1.351655 time=304.3s + ttt_chunk [1061/1389] bpb=1.352013 time=307.2s + ttt_chunk [1071/1389] bpb=1.351704 time=310.1s + ttt_chunk [1081/1389] bpb=1.351945 time=313.0s + ttt_chunk [1091/1389] bpb=1.351690 time=315.9s + ttt_chunk [1101/1389] bpb=1.350981 time=318.8s + ttt_chunk [1111/1389] bpb=1.350249 time=321.7s + ttt_chunk [1121/1389] bpb=1.350347 time=324.6s + ttt_chunk [1131/1389] bpb=1.350820 time=327.5s + ttt_chunk [1141/1389] bpb=1.350778 time=330.4s + ttt_chunk [1151/1389] 
bpb=1.350335 time=333.3s + ttt_chunk [1161/1389] bpb=1.350347 time=336.2s + ttt_chunk [1171/1389] bpb=1.349906 time=339.1s + ttt_chunk [1181/1389] bpb=1.349999 time=341.9s + ttt_chunk [1191/1389] bpb=1.349586 time=344.9s + ttt_chunk [1201/1389] bpb=1.349740 time=347.7s + ttt_chunk [1211/1389] bpb=1.349535 time=350.6s + ttt_chunk [1221/1389] bpb=1.348910 time=353.5s + ttt_chunk [1231/1389] bpb=1.348707 time=356.4s + ttt_chunk [1241/1389] bpb=1.348441 time=359.3s + ttt_chunk [1251/1389] bpb=1.347966 time=362.2s + ttt_chunk [1261/1389] bpb=1.346894 time=365.1s + ttt_chunk [1271/1389] bpb=1.346167 time=368.0s + ttt_chunk [1281/1389] bpb=1.345413 time=370.9s + ttt_chunk [1291/1389] bpb=1.344800 time=373.8s + ttt_chunk [1301/1389] bpb=1.344375 time=376.7s + ttt_chunk [1311/1389] bpb=1.343763 time=379.6s + ttt_chunk [1321/1389] bpb=1.343495 time=382.5s + ttt_chunk [1331/1389] bpb=1.342298 time=385.4s + ttt_chunk [1341/1389] bpb=1.341490 time=388.3s + ttt_chunk [1351/1389] bpb=1.340605 time=391.2s + ttt_chunk [1361/1389] bpb=1.340014 time=394.1s + ttt_chunk [1371/1389] bpb=1.339381 time=397.0s + ttt_chunk [1381/1389] bpb=1.338933 time=399.9s + ttt_chunk [1389/1389] bpb=1.338717 time=402.0s +ttt_sliding:done val_loss=3.077504 val_bpb=1.337462 elapsed=402.0s +final_int6_ttt val_loss:3.07750371 val_bpb:1.33746178 eval_time:402404ms diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log index d306ef283f..39b99295c4 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_seed999.log @@ -1,7 +1,7 @@ -W0403 12:31:42.753000 46995 torch/distributed/run.py:803] -W0403 12:31:42.753000 46995 torch/distributed/run.py:803] ***************************************** -W0403 12:31:42.753000 46995 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. -W0403 12:31:42.753000 46995 torch/distributed/run.py:803] ***************************************** +W0403 13:59:10.343000 67522 torch/distributed/run.py:803] +W0403 13:59:10.343000 67522 torch/distributed/run.py:803] ***************************************** +W0403 13:59:10.343000 67522 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
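Editorial note: the headline metric in this record is the `final_int6_sliding_window` BPB rather than the plain round-trip number. A sketch of what a stride-based sliding-window eval computes, assuming `model(ids)` returns logits and borrowing stride=64 from the `ttt_sliding` line above; the exact eval parameters live in the compressed code and may differ:

```python
import math
import torch

@torch.no_grad()
def sliding_val_bpb(model, tokens, total_bytes, seq_len=2048, stride=64):
    """tokens: 1-D LongTensor. Score each token with up to seq_len of left
    context by advancing the window `stride` tokens at a time and counting
    only the last `stride` positions; convert mean nats/token to bits/byte."""
    nll, counted = 0.0, 0
    for start in range(0, len(tokens) - seq_len, stride):
        window = tokens[start : start + seq_len + 1]   # inputs + shifted targets
        logits = model(window[:-1].unsqueeze(0))[0]    # (seq_len, vocab)
        logp = torch.log_softmax(logits[-stride:], dim=-1)
        tgt = window[-stride:]
        nll -= logp[torch.arange(stride), tgt].sum().item()
        counted += stride
    bytes_per_token = total_bytes / len(tokens)
    return (nll / counted) / math.log(2) / bytes_per_token  # bits per byte
```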
+W0403 13:59:10.343000 67522 torch/distributed/run.py:803] ***************************************** Hyperparameters: adam_eps: 1e-08 adam_wd: 0.02 @@ -27,7 +27,7 @@ Hyperparameters: iterations: 20000 ln_scale: True local_rank: 0 - logfile: logs/0820b928-8f26-40ae-9974-ef0bb3664b8c.txt + logfile: logs/0a42f21a-b15f-4cae-bda4-a71125cf74cd.txt logit_softcap: 30.0 matrix_lr: 0.02 max_wallclock_seconds: 600.0 @@ -44,6 +44,7 @@ Hyperparameters: num_heads: 8 num_kv_heads: 4 num_layers: 11 + parallel_start_layer: 7 qk_gain_init: 4.0 quantized_model_path: final_model.int6.ptz rank: 0 @@ -52,7 +53,7 @@ Hyperparameters: rope_base: 10000.0 rope_dims: 16 rope_train_seq_len: 2048 - run_id: 0820b928-8f26-40ae-9974-ef0bb3664b8c + run_id: 0a42f21a-b15f-4cae-bda4-a71125cf74cd scalar_lr: 0.02 seed: 999 skip_gates_enabled: True @@ -86,7 +87,7 @@ Hyperparameters: xsa_last_n: 11 train_shards: 80 val_tokens: 45508608 -model_params:34401371 +model_params:34401372 gptq:reserving 10s, effective=590000ms warmup_step: 1/20 warmup_step: 2/20 @@ -97,36 +98,36 @@ warmup_step: 6/20 warmup_step: 10/20 warmup_step: 20/20 0/20000 val_loss: 8.3152 val_bpb: 3.6137 -1/20000 train_loss: 8.3175 train_time: 0.0m tok/s: 8427288 -2/20000 train_loss: 12.3306 train_time: 0.0m tok/s: 8355799 -3/20000 train_loss: 10.8414 train_time: 0.0m tok/s: 8254132 -4/20000 train_loss: 8.9815 train_time: 0.0m tok/s: 8207037 -5/20000 train_loss: 7.7899 train_time: 0.0m tok/s: 8175906 -500/20000 train_loss: 2.9026 train_time: 0.8m tok/s: 7928345 -1000/20000 train_loss: 2.8868 train_time: 1.7m tok/s: 7897402 -1500/20000 train_loss: 2.9194 train_time: 2.5m tok/s: 7887462 -2000/20000 train_loss: 2.6598 train_time: 3.3m tok/s: 7882463 -2500/20000 train_loss: 2.7139 train_time: 4.2m tok/s: 7880860 -3000/20000 train_loss: 2.7634 train_time: 5.0m tok/s: 7882252 +1/20000 train_loss: 8.3175 train_time: 0.0m tok/s: 8488051 +2/20000 train_loss: 12.1381 train_time: 0.0m tok/s: 8348555 +3/20000 train_loss: 10.6781 train_time: 0.0m tok/s: 8270216 +4/20000 train_loss: 8.8559 train_time: 0.0m tok/s: 8223080 +5/20000 train_loss: 7.6732 train_time: 0.0m tok/s: 8191350 +500/20000 train_loss: 2.8973 train_time: 0.8m tok/s: 7953960 +1000/20000 train_loss: 2.8869 train_time: 1.7m tok/s: 7928240 +1500/20000 train_loss: 2.9135 train_time: 2.5m tok/s: 7916078 +2000/20000 train_loss: 2.6566 train_time: 3.3m tok/s: 7912208 +2500/20000 train_loss: 2.7065 train_time: 4.1m tok/s: 7912039 +3000/20000 train_loss: 2.7609 train_time: 5.0m tok/s: 7912266 recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10] -3500/20000 train_loss: 2.6898 train_time: 6.1m tok/s: 7479088 -4000/20000 train_loss: 2.6233 train_time: 7.1m tok/s: 7390618 -4000/20000 val_loss: 2.6466 val_bpb: 1.1502 -4500/20000 train_loss: 2.5741 train_time: 8.1m tok/s: 7322551 -5000/20000 train_loss: 2.5200 train_time: 9.0m tok/s: 7269901 -5427/20000 val_loss: 2.5333 val_bpb: 1.1010 -stopping_early: wallclock_cap train_time: 590048ms step: 5427/20000 -peak memory allocated: 30119 MiB reserved: 30156 MiB +3500/20000 train_loss: 2.6841 train_time: 6.1m tok/s: 7513499 +4000/20000 train_loss: 2.6154 train_time: 7.1m tok/s: 7421303 +4000/20000 val_loss: 2.6404 val_bpb: 1.1475 +4500/20000 train_loss: 2.5708 train_time: 8.0m tok/s: 7350918 +5000/20000 train_loss: 2.5134 train_time: 9.0m tok/s: 7295489 +5444/20000 val_loss: 2.5255 val_bpb: 1.0976 +stopping_early: wallclock_cap train_time: 590043ms step: 5444/20000 +peak memory allocated: 30120 MiB reserved: 30154 MiB ema:applying EMA weights 
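Editorial note: the new `parallel_start_layer: 7` hyperparameter above corresponds to the parallel-residual scheme described in the README below — from layer 7 onward attention updates lane 0 while the MLP updates lane 1, and a learned `lane_merge` scalar blends the lanes after the final layer. A minimal sketch under those assumptions (the block structure and the sigmoid blend parameterization are illustrative, not the shipped code):

```python
import torch
import torch.nn as nn

class TwoLaneStack(nn.Module):
    def __init__(self, blocks, parallel_start_layer=7):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)  # each block assumed to expose .attn and .mlp
        self.start = parallel_start_layer
        self.lane_merge = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        lane0 = lane1 = x
        for i, blk in enumerate(self.blocks):
            if i < self.start:                  # ordinary sequential residual
                lane0 = lane0 + blk.attn(lane0)
                lane0 = lane0 + blk.mlp(lane0)
                lane1 = lane0                   # lanes stay in sync
            else:                               # lanes diverge from layer 7
                lane0 = lane0 + blk.attn(lane0)
                lane1 = lane1 + blk.mlp(lane1)
        a = torch.sigmoid(self.lane_merge)      # assumed blend parameterization
        return a * lane0 + (1 - a) * lane1
```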
-pre-quantization post-ema val_loss:2.53084430 val_bpb:1.09987578 eval_time:2003ms -Serialized model: 132405827 bytes -Code size: 23948 bytes +pre-quantization post-ema val_loss:2.52314821 val_bpb:1.09653115 eval_time:2009ms +Serialized model: 132406149 bytes +Code size: 24522 bytes GPTQ:collecting Hessians from calibration data... -GPTQ:collected 66 Hessians in 9.7s +GPTQ:collected 66 Hessians in 9.8s GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search -selective_prune: unpruned=15.98MB target=16.0MB -selective_prune: already fits, no pruning needed -Serialized model int6+brotli: 15953548 bytes -Total submission size int6+brotli: 15977496 bytes -final_int6_roundtrip val_loss:2.55674680 val_bpb:1.11113270 eval_time:7580ms -final_int6_sliding_window val_loss:2.51447678 val_bpb:1.09276264 eval_time:75764ms +selective_prune: unpruned=16.01MB target=16.0MB +selective_prune: pruning 70144/9355215 lowest-error ±1 values (excess=8768B) +Serialized model int6+brotli: 15966085 bytes +Total submission size int6+brotli: 15990607 bytes +final_int6_roundtrip val_loss:2.55009963 val_bpb:1.10824392 eval_time:7777ms +final_int6_sliding_window val_loss:2.50709990 val_bpb:1.08955674 eval_time:76465ms From acdf503bc32ee52ffd6e5c2bce91f72197ae813b Mon Sep 17 00:00:00 2001 From: Aryan Bhosale Date: Fri, 3 Apr 2026 21:10:15 +0530 Subject: [PATCH 4/4] =?UTF-8?q?Update:=20QK-Gain=205.0=20+=20TTT=20fix=20?= =?UTF-8?q?=E2=80=94=20val=5Fbpb=201.0897=20(3-seed=20mean)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit QK-Gain from 4.0 to 5.0 plus parallel residuals and depth recurrence. 3-seed mean: 1.0897 BPB (std 0.0003), delta -0.0250 vs merged SOTA. --- .../README.md | 59 +-- .../submission.json | 30 +- .../train_gpt.py | 2 +- .../train_seed314.log | 68 ++-- .../train_seed42.log | 354 +++++++++--------- .../train_seed999.log | 64 ++-- 6 files changed, 270 insertions(+), 307 deletions(-) diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md index b340021aa8..aa1ab36f45 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/README.md @@ -1,46 +1,30 @@ -# Record: SP4096 + Depth Recurrence + Parallel Residuals + MuonEq-R + Full GPTQ — val_bpb 1.0904 (3-seed mean) +# Record: SP4096 + Depth Recurrence + Parallel Residuals + MuonEq-R — val_bpb 1.0897 (3-seed mean) -**val_bpb = 1.0904** (3-seed mean, std 0.0016) | **~15.98 MB** | 8xH100 SXM +**val_bpb = 1.0897** (3-seed mean, std 0.0003) | **~15.99 MB** | 8xH100 SXM ## 3-Seed Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128) -| Seed | Steps | **Sliding BPB** | Artifact | -|------|-------|-----------------|----------| -| 42 | 5,279 | **1.0923** | 15,965,928 | -| 314 | 5,279 | **1.0894** | 15,997,318 | -| 999 | 5,279 | **1.0896** | 15,990,607 | -| **Mean** | | **1.0904** | | +| Seed | **Sliding BPB** | Artifact | +|------|-----------------|----------| +| 42 | **1.0894** | 15,999,165 | +| 314 | **1.0898** | 15,997,318 | +| 999 | **1.0899** | 15,990,607 | +| **Mean** | **1.0897** | | -Merged SOTA (PR #1019): **1.1147 BPB**. Delta: **-0.0243 BPB**. +Merged SOTA (PR #1019): **1.1147 BPB**. Delta: **-0.0250 BPB**. -## Changes from Merged SOTA +## Key Techniques -Five orthogonal improvements: - -### 1. 
4096-Vocab + MLP 4x + WD 0.090 -sp4096 tokenizer, wider MLP (4x vs 3x), higher weight decay for better quantization compression. Source: PR #1218 @clarkkev, PR #1285 @dexhunter. - -### 2. Depth Recurrence (layers 4,5) -Virtual 13-layer network from 11 physical layers, zero extra params. Activates step 3000. Source: PR #1204 @msisovic, PR #1260 @dexhunter. - -### 3. Parallel Residuals (from layer 7) -From layer 7 onward, attention and MLP operate on separate residual lanes. Attention reads from lane 0, MLP reads from lane 1. A learned `lane_merge` scalar blends the lanes after the final layer. Source: PR #1204 @msisovic, PR #1289 @MatoTeziTanka. - -### 4. MuonEq-R -Row-normalized Muon optimizer (arXiv:2603.28254). Source: PR #1260 @dexhunter. - -### 5. Full GPTQ int6 + Brotli + Compressed Wrapper -All 66 layers at int6, brotli-11 byte-shuffle, LZMA self-extracting code wrapper (~25KB). Source: PR #1019 @abaybektursun, PR #1218 @clarkkev. - -## Architecture - -11L/512d/8H/4KV, MLP 4x LeakyReLU(0.5)^2, XSA all, QK-Gain 4.0, Partial RoPE 16d, LN Scale, VE128 (9-10), sigmoid-gated U-Net skips, EMA(0.997), MuonEq-R (lr=0.02, WD=0.090), depth recurrence layers 4,5, parallel residuals from layer 7, full GPTQ int6 + brotli-11. +1. **4096-Vocab + MLP 4x + WD 0.090** — PR #1218 @clarkkev, PR #1285 @dexhunter +2. **Depth Recurrence (layers 4,5)** — PR #1204 @msisovic, PR #1260 @dexhunter +3. **Parallel Residuals (from layer 7)** — PR #1204 @msisovic, PR #1289 @MatoTeziTanka +4. **MuonEq-R** — arXiv:2603.28254, PR #1260 @dexhunter +5. **QK-Gain 5.0** — PR #1217 @bigbag +6. **Full GPTQ int6 + Brotli + Compressed Wrapper** ## Compliance -- No TTT, no SLOT, no n-gram cache, no eval-time adaptation -- GPTQ calibration within training budget -- All four conditions from Issue #1017 satisfied +No TTT, no SLOT, no n-gram cache, no eval-time adaptation. All four conditions from Issue #1017 satisfied. ## Reproduction @@ -53,11 +37,4 @@ torchrun --standalone --nproc_per_node=8 train_gpt.py ## Credits -- PR #1218 @clarkkev (4096-vocab + MLP 4x + brotli) -- PR #1285 @dexhunter (WD 0.090 + all-int6) -- PR #1204 @msisovic (parallel residuals + depth recurrence) -- PR #1289 @MatoTeziTanka (parallel residuals integration) -- PR #1260 @dexhunter (MuonEq-R + depth recurrence impl) -- PR #1019 @abaybektursun (GPTQ + XSA-all) -- PR #1287 @dentity007 (base code) -- PR #493 @parinzee (LeakyReLU^2) +PR #1218 @clarkkev, PR #1285 @dexhunter, PR #1204 @msisovic, PR #1289 @MatoTeziTanka, PR #1260 @dexhunter, PR #1019 @abaybektursun, PR #1287 @dentity007, PR #1217 @bigbag, PR #493 @parinzee diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json index f05ecd01c7..e2909dfb7f 100644 --- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json +++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json @@ -1,34 +1,20 @@ { "author": "aryanbhosale", "github_id": "aryanbhosale", - "name": "SP4096 + Depth Recurrence + Parallel Residuals + MuonEq-R + Full GPTQ", - "blurb": "4096-vocab + MLP 4x + WD 0.090 + depth recurrence (layers 4,5) + parallel residuals (from layer 7) + MuonEq-R + full GPTQ int6 + brotli. 
diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json
index f05ecd01c7..e2909dfb7f 100644
--- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json
+++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/submission.json
@@ -1,34 +1,20 @@
 {
   "author": "aryanbhosale",
   "github_id": "aryanbhosale",
-  "name": "SP4096 + Depth Recurrence + Parallel Residuals + MuonEq-R + Full GPTQ",
-  "blurb": "4096-vocab + MLP 4x + WD 0.090 + depth recurrence (layers 4,5) + parallel residuals (from layer 7) + MuonEq-R + full GPTQ int6 + brotli. 3-seed mean: 1.09042 BPB, beating merged SOTA (PR #1019, 1.11474 BPB) by 0.02432 BPB.",
+  "name": "SP4096 + Depth Recurrence + Parallel Residuals + MuonEq-R + QK-Gain 5.0",
   "date": "2026-04-03",
   "track": "10min_16mb",
-  "val_bpb": 1.09042220,
-  "val_bpb_std": 0.00163473,
+  "val_bpb": 1.08971631,
+  "val_bpb_std": 0.00028794,
   "seeds": [42, 314, 999],
   "seed_results": {
-    "42": {
-      "val_bpb": 1.09230771,
-      "artifact_bytes": 15965928,
-      "steps": 5279
-    },
-    "314": {
-      "val_bpb": 1.08940214,
-      "artifact_bytes": 15997318,
-      "steps": 5279
-    },
-    "999": {
-      "val_bpb": 1.08955674,
-      "artifact_bytes": 15990607,
-      "steps": 5279
-    }
+    "42": {"val_bpb": 1.08938974, "artifact_bytes": 15999165},
+    "314": {"val_bpb": 1.08982552, "artifact_bytes": 15997318},
+    "999": {"val_bpb": 1.08993367, "artifact_bytes": 15990607}
   },
   "comparison_baseline_pr": 1019,
-  "delta_vs_pr1019_bpb": -0.02431289,
-  "artifact_bytes_max": 15997318,
+  "delta_vs_pr1019_bpb": -0.02501878,
   "hardware": "8xH100 80GB SXM",
   "pytorch_version": "2.9.1+cu128",
-  "technique_summary": "SP4096 + MLP 4x + WD 0.090 + Depth Recurrence (layers 4,5) + Parallel Residuals (layer 7+) + MuonEq-R + Full GPTQ int6 + Brotli + Compressed Wrapper"
+  "technique_summary": "SP4096 + MLP 4x + WD 0.090 + Depth Recurrence + Parallel Residuals + MuonEq-R + QK-Gain 5.0 + Full GPTQ int6 + Brotli"
 }
diff --git a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py
index ea4e233c48..69b28c1f24 100644
--- a/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py
+++ b/records/track_10min_16mb/2026-04-03_SP4096_DepthRecurrence_MuonEqR_GPTQ/train_gpt.py
@@ -1,2 +1,2 @@
 import lzma as L,base64 as B
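The unchanged first line of this hunk is half of the compressed wrapper: the entire training script ships as one LZMA-compressed, Base85-encoded `exec()` line (the logs earlier report "Code size: 24522 bytes"), and the `""` pairs inside the payload below are most likely adjacent string literals the packer emits as chunk separators, which Python concatenates at compile time. A minimal sketch of how such a wrapper can be generated; `pack_wrapper` and the chunk width are illustrative, not the submission's tooling:

```python
import base64
import lzma

def pack_wrapper(src_path: str, out_path: str, width: int = 120) -> None:
    """Emit a two-line self-extracting train_gpt.py from the real script."""
    with open(src_path, "rb") as f:
        blob = base64.b85encode(lzma.compress(f.read(), preset=9)).decode()
    # b85 output contains no quote or backslash characters, so the payload
    # can sit inside a double-quoted literal. Splitting it into adjacent
    # string chunks ("abc""def") is harmless: Python concatenates them at
    # compile time, which would explain the "" pairs visible below.
    payload = '""'.join(blob[i:i + width] for i in range(0, len(blob), width))
    with open(out_path, "w") as f:
        f.write("import lzma as L,base64 as B\n")
        f.write('exec(L.decompress(B.b85decode("%s")))' % payload)
```

At import time the two emitted lines simply reverse this: b85-decode, LZMA-decompress, and `exec()` the real script, which is exactly what the hunk body below modifies.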
-exec(L.decompress(B.b85decode("{Wp48S^xk9=GL@E0stWa8~^|S5YJf5;YFWHtz7^#n@VT6Qap3bu0*kgCR~YUqB0W9R)iarr*QtEZpesGY3>~CZRiK|6Dwut$nH#N""!RYqQnA}G^`ZsFO;ar92)Xt#3E3Ki5S1}OfSx<=$c<4=h|J{kt$27^CQ01M+lVgZ0tGgX0&I*V@{U&JgYc0U!(4F-btCy*+qzv6D""p~UW!y~6{U*}y$E@2-R}vd?t*s#fnDO{!j>OImt34A(d+9n>hnnvzmd((_D1Cghg~(bQ$Yj)>!Y%{*o9ex8FWa#U)!OI!!5Prl^?bnBX2V(=(Bvc+CvGo!S{LhLn7pSsR!}@""U=OBW0)h6IYneQ1{|$<&k9TS^qGQpb-;#vEPAl%11UF)?6mtC8c04XzR$+h2=j84E2|i`pOEt$uyM`lGs*ejIF-}^SvSRZK$ePh1""`gt+?%1r#=OVy3pW`{ofc)6PhQQRP|_h56zl+sQ(le1eJ^&&qZxdGb15""aOb^-R1ouqi-H1_w|H;g(()bKz_!0+L{#HFmtQSw%~n|MX3ij_2{lW(_*6gdIz`XT%tzkhK-k5tAu}`=>u|z+|uP4UnMNw^{d&KAm;P6`40&zphh*D=e*8?KGZuo~y*`y#Wg}r(PV}?J$Of6s2V+h!p=48x=kjDY_aX`;C#9(p^)jAnKY3g{5DbCkKssjYrRZXs2-`q~ph{?^weRV5FIKs2lBmbcg&RzLuOlaj@{R`Bfpp)eb""D}Ew7C`>;2>B3jf`(s3Yq*>}BxRhe{CvdtS;G|JdNs#+jk4ep3dje2Aqp%-8na-AvoSenabJ|$$wl8x*@Ue0+ZLIPX!Z`X`fDGAy""*mN{vNnb*|);N0fh-Z`3Aau{>lWape5IR~oJnXs<;zeX^21})nE$qd8r&8ti65NPa0a~or2;Suxwo%f(tuSS@rF*DUP|;0zJT4F7""RlJu9oEB!E%i}B&x5U4}gtIp-6SMZ{+dkQZxc4r^bW|Az""23B02v@$$zZ%Q|e3t}6^@m@%ylth+=fP-&da%J`u${*%cAl8_70iLZEWXUfUjlL&RR^jHS9Tq95&vuGS2b5hx9w73Pd7u!;`;|Sv4%Kl~^O{I5Sl``;c77Gm&9srDuwH$#63K-RNdMg%p;!1ot7Jv0GK3ctSXc9RL;""g&^SBVjM<_jiV|I=%-4fq>|RpMKBQw;b#%_BIGzZ#XxSg@>j8}%*~8RJ6Pg4!7!5a9>N*+AlwvQ(uv4jICUoy4BG((kdo(}s}`Vv""1h!lBD{Fe`&jW-nqAF+l)dTQ)miJ+XWf+Bx5B39&Y&_c7nY6x^oh_zcHJOko`G7%$G7USc_F>xQc9n@jCYhVoUh?$}#BRZ^4uq_2""7s@t04Q9?!>yQDglhT^N2`YZ}W&g%U*+!q1Q8742Y|l(6#?^AU5Z5gV%-a0iERpi?v=HHVfjN}^6}a?}`q(n5@+Q`-%DW;bNf;Ef""(^-HggnS+gU$2LF$rp#h0Wy0{eG0thz}nPW`LO2Vo!k#dtZ=zHNxR4(2G-CZt}-<{jWZDRbsO@SnxD(""b8WvO+#Yoi*I?Ig#z9Q!8bUOMgidxfL3jJsds0HNh<{P}z+Z9^<%S9bL-*`;H7?ij0DLvEv<2>_I^ljL0F4;6hFjrBuZMbyksHh7pF|%iHA@V?2rWapvolwjHFu@DnNgX?""@ndz>DS(WZ6SyuBUx9aogimYJ-6OnW3G;q_w@lPU@zPb4!(=~XaPSAWHM@u`A=*pWT;kfDKmRzGpuc?UT1{*G{W%3""dY+IfM^;OE^&VtQW(NzSc0q{x2ls>V@3_w@lhBX?abxhiaNL8V5lwSXSyvKV(_w5g_wvG#mCE75E74wPhQAzuw>d4XM9z^oih&_7""hU&JHC4k0#E|OMFmdq5}sJ`7<^GvlJ_5kK%gwZiV-ti?K1TaXvL+`{$RRkO$jAu!eq""ga%GNi+ODunTk?S9_-O$glv8oSflYG;ode^PGGk_wSH%kr;l5Zy*CYU8k#YuvUyX?zz1V_5-r-L$7|Esx}6vhGEQxjc&q=L62}rI""E}`($@GXxR#1sO-!EmCKp|l|;ft}Gk2s>$Lh={&QRFYO#x`f&72H=#IbhgIcE7n)0gon>9`z4h9p0Fy#<7i7#anxl#Z4C6zU0j|E""hw3P-5y7kkjiRoBe|H1&9-&@9+ey}!%9d|{7VSq5VV$Y`#CSgDNpe<%8l^tJ8WZRVa|^@GhM`(^8lOXVNm!^MK(STQsj9cH0DW1+ci7FR^{Hl*A1l-t<@84cLt=}eZc>*b-_Q?_s;fA`bkQdo%!bI9M6-!Z!61W!gz""obH{xgfKETjMNfCT+rfuz23HS98c*Jh;7lBjMf;Y?v;B?o2HX%QJG@&5Y*?d_9_;zGYvvH;PW{@i_qlTQJdc*FK3LI+C{""sE~lcPiIOG{{Ublk&=rC=u&eusEwBJ+rFZZnMXU9gx{7w)%nKVg{Y@H$oijVan|1Vp*FVOtRj81VhDi-+>uV4l(^YNHqCJ5Vs9sc""+IvZ)JJ(D!tkpXGU19WClu_8&=JZpniZ@=Mr-Bp6t6q@;2L`!(XRYmsksc@D67xrj}LIK"">@jGp(eYTA@glWm11U;5;AI991*670Lrxpt)TAwhBueqWtFNFd-q{`ap!+42K~EL|5g-G7gf;}Ae@N};sgR{bS+YIg;3eVA5oQY0=jE9RSy8aX@|m#iE!?RCDS*|$Jka+=rav`5(_6`@tjW%$KND+>j>m9bWBN;940MUK<&zs(PA_v*=pEy""#=iKj_vMjF6F&&~+|>%?<&XFI5H0e;TO&;4;9;exq773k(xegV3ZCCP&=j47lU&BUqH2n6+@EJ>ed&PLa6c?C$rT2$ySigJp4O@5""Lc+_fEN^A6-p=hCn|S8P42!#(=W)dbrI@}ezD|pnyg6QZqjCv}%=b(8f@igY_B~fIdDbCiP&S9sx9=(iC(W23{pXMph*Se=0pp;8z8Klv?NX=i*!MS)7+FB%m!{nvE!p_Q6SIud5x&=6JrV|J5zIp43_euOi0K""WD=(7*VVP1oZ=a>g#>_dij3vL_l!8f2G$oy5hHlWVl)=tRiIJH-+ym+s2;Z;uG^Z3?y#>!DIzPy!of=oRgYXZM}1m(mJZ#pM8+MQqnQ4FSHk?`NPYN;jM;9*@9eMD""mPT!=6V#g0abmJa78KIrWo}0v?vM=)4%gYb*znENMJdiy!mVc?{AcVvLW=0fk-PYP~*ubwD~$u^Yaw!zRJseh9S?gRV*!CO?i""j+(RCC|?htf#@1$FbzCO7+l}ugcWir*^U<98~wx{>RyW?JZ=L3C2|}BElWP$g|BHKdFyWYgTKFB5A2p~cR-AMt>&w^s3jjq*oz9B""Y-D->I&`a-$XxXPs%CVc`r}IZg)(kXqLCI)vSjJ}f`lnYWkaX&T@W@WtpM&Bc^N+?luK26B`7%uUP8|9+gjQX{&s%hm7_hl?r;IV""NFS9fHEfzgX5Q4#*^+OfOu1kLXLzf1NS4jQZ41#c()>NDX|v8gB7lf^t?`Av""h}5H0wAaxFd_BC-5
MC6iqgmoO_M-{aukTC~=U6sBEypj0#6%R>Wp?m<>5Vg11maFY_;G`^~Eza^nMZ{JSovKRCy9>38Y0A`)Nic""=oqW#MMyNBnVdT@*GIQWz(L%otkIlZgH>X`@fH&4@Jm_g)?xLmBbZK21R=nEFu0vtdUU&Qxjh+nUP}&GkZllla0TL>-reIDJr<|TrKzn<_G={3i{=F|8kMHy{E{#$Vf_Vj;IVf>%eofKtK""oAa3oVy~p>%OLZX;D^)XKJlJ;gNuCM&G(T0)*1MDl{JF-c{3z*1202RO=gM))Z9P{jXz4^W%+3{9""i2loyx$b#@Uh^|5mwj5$NRHN8cSfUPSc*Jgc53>MIw6vOZcheQRY|=HNYTH*DUZq~(Lu~MqLt|Q%GPc3X(ZM_K8%Zr-Ac)+*BF-%""N{pdt!FBip@hugGl!6`0kU3Ju1D%LODFp^+$Is@8n8{m~$bqDMCge}nRcBEY?0r>w7^cqr5_B|T{""u%&?X0@C7oBVfKL71q2*u+9|i7Ko;hezQK+5$LLfy9tUg%27Iq?RUu&5wljmm~wS)A3yQs!=t_bOl2`_SZyzC5E&u5UmXf6u3VVTov3uXp+FeLp7mB3EJdDKhNofRAXyN;du^8G}rB?vbmTrEOua@eoN7<@gvI""A>P1GobI;Q__y-&$D-9f=uQkSXPp_#4H&xDo(fZDx6U!3jPfZL0Cnvyi#4a4A9#4-~!""x;P{#db3Ln*<&Hs&fx4lE_4p&BF+jAfTP<=(%8?5QDOYmpmpI;SM3wv+Tgr{Lw""&MUQDwRhMFgKxJ}o1KXLqE00p>dg}ER5IdP%%cLhy7dBhqD!vA*{C^LDI_rK9amZ^w;MJb3y!Ll>|#uuz3gncUmbYr+a3;JoPZ0B""fbsoszK&=OM%u(78P@ekd!;kT8U8Y!w3PZBhIO6t8""f!DgpCYrqR@QwEsa828~%0HbxI)o`$AHb=TnE9d>+7}tOBp3yeakb2YPZ4PY?*#GyElWk$s{?9sB""{RLh6#&a>=;g_F+)^AAHvvg42Eit3a75)PM9$gZN>@M=@!Ucr;oGR_M?7O;4DDafPQykq8Sb{j>p6lHoW4);J1-ZO*H~g$~ASR8q""m3DvA$VB^Dj{Yb}8Gcm_?52MJxsUbkS16&1@!#nNyB@P5Xn=^$r+=nQ#Wor-(OHO|T-Ts^aMn#m(Fr5`&*1F27hkU$JqZ)G@UoPg""TCrW@xK1PRSJ_%ryQXOd>!)bid@!_hezg?^~!{""Encs7kNQM%YJ|_gLXogC$kBH+A_C>etC82dbtT+nff3&$FgoY^llp@yK%<**skYKSwFMVF#o5l(fb^qj251+u$w%EbO}#I#`5#i!""i_dPVW&@K{-~1uD3)0d3byNCLlX9f0U2Y-AlqtA&g3=qhVmkp=dP(;5k?oDv7;$lge}o)1?=jnL3?j9yn03^c|L-?jKU#=Mjsaq~""ciAXkQ{>o+_YttQ2xyqUA78Vmkg0shDp6;Yq1Ku$x#fp&TCJLAjDgOo;y=Y|2xGxBaWetiODmKgIkDm+G-u~h@FD-pcIO(5j@Bev""w>EOt(SFXQ8I_4lt@%;oF&#D0+TUz}Nfb?q8QQ~Ur<;Ji^xLP*082}!Y7?5;T++L}wc-oGJMG$1u{-3-WEXDagir>=gDefit@GS""K1+5ozDX~oryl${x>Be0^&Vrs4fmO=aZF|""N4coCYP7sTRok_hcrN{}q$!_Tjkq^#Z__JN;Ju9$Gy+r6rFqO_=S45D_lUBY(_I+~(cCMx?Bbx3`eutdj12a-TiIv9P2XjCda1d9""ONdhX(zRe(T{{1nWSN=RN)E^i7Nc^<#EqOOK%o|_Lel5B6hY3OR(G`pUQZIOFNOoAmLa9jEX?3g$UGat^e}-A(vF3yn5`Kjhbye7""yJgLnmF;u(Z69Zl5r5#o2Mut?M%p_~N4wzgN$&O$8!B&9Uav-u62n88>Zpb|=?;oIzsZm&rL2N|DYB12tDjq);WiUu9Y2U>-sS6N""*87o=0j8VK!oUaB-6LoOW!$y!1{)&~|r@M+_S|R6tzIV&3hQ!REwLv1I__0oyvMfb3Fjk(m1a9)gb1CYrUcfu%rTfa7m5W|($o{UlNx=}hcc&O""!I8y?IY?OlVU2b6q=%AnC(dPoTWOR8g)zyc;+np8{L5-{sFNJ`b!Si4W5}k6w^?~U)<-?#P0z-8k}gR-$8-#""U>L0xS=JxSPsC{q&KkB0;=i{Eh!67Gdm2zvCw(ZgVxt22ry*DF)?wXSu8HQ&xp8I_GK9Z}>SM$rR_d!g80FT!H5pfGc%>pySL|ar""XPtVyAC@KlCKyZ_nGTGbv4h6WfFiUg6*uH7p0r6*W>YJCJHIoTVtO=brVWBX^W&<^m?C#}4{<@M-=-LxwcQw)lr(~?3*ZJL>j%E4""OaRcz8Tk<*w&H9HLiS~%VqkYVW<%W9cHmY=X#|M(!c?^^&auF7g;g$b=#8_CiJ9UBBkcG)v(j82}a={2#H`y4S}(8sPvy63P>Vp&UMV-FgeAfl1*L6vZgJp""@BV>PCthn|IweqpLkfn5rW-k|ia}gxglH@OJctg_x5jHI6cHgQ#)BKOu}AFhkeD7PN00Vta03CpQ`Nyq>0;^O-AD>MZsY^4eWysW""rv9&_=<56ln;o?j^z1IWQhq8$zno}7Fucp69()8h7e7d(Axe|*Hz?-V7%A!>SM6>ow3iJ=U-SlC;^mZf<3ffOuu(!J=a1fcx)n`K""JB|YT-Dor|J6wYD8h;u+V{nuB4|nM%v8lL3rL@iSln5pdUWoKT(oPK%S8=Fny`9#d+?LW{D0Xny9qndx=mjFCL%upkbt%I+|9NWe""-GlS!G;X|}(k?kvneqb73*-z1FSqs3E-ljfnr<#Vk1g@5US3gd7#II6(lbW7l40ECKb@rg($O>f&aN0Tq#FsymF)R_L33t%QLl#m""Op0?rd%s-cw5@;smu~HvVmX&%tvSGDVDgFY*`lLgjqz@W?JJb#mX3|i?r3(ATft785Y+z}pvt!pK1xQ9Y1B(>by`_RR(rzoiE9D1""d_g#^b*Y7LJ>Y~o#|wC%@UX}>1H2)O@*R9F(CurItl9+x}q^b;ftvw;fQ{M`{mQE}Jp?*sL*M04I-i4>$""tO$&Co9uHbLDaSOpkhd8kXT>Rufiz!{@B{>GKlu49j!xwkmcPl-Z|f8POZp{RFa!>NvLX~I-8A-BqX?G29P8%v1`!LcbRDMo9>BGuvKQIE+p4|k~2_dT7KD#o1zBg-{_ct+BEqze)ec}""^`s-e_}(Bi-aG^M#DGRCp5;TtMfu*qbFF7+{acH2YSk{(EVMle-WoFhS)c;B9D}7N%=c?+ER}!;xSdCrUq{iuOdn#=i?2`;#5KM^""8y=Zg;U0Z?d8?RqMphKh5GZHNL$(+p`OGY_Jge0a+m;hmPzVqtg~l~2)C=v#;l=APyFxMmu*JNOCdmpj_15_Mx@90Lj`)w;^|l>p""8RG_#(vyO%$#22p<%Giz^vz4F2Ks7TZ^r3%s5+A630AaRTD{#&cu)IHf|{u5{@)tClA{*39|NP$ATe4E5i)1zYA$nK4I""oWo
v!DG&l2CU=U!E""b!UeV&1qbq9l-D`e#&HDul(^Ja0kI5nV$20hbdJO=8K{v1|2(f(OqsEUH|aAk7s48Z0n4ENjNX0-y%2_JLHszP}@9bC9y=WA3RKL""V)qjAhj2^`;2&DAsqEWX$YxFd$gDOJ0PoPK0PC=V=wr4mcdI(nrwsJZ1^c%KPHh7zpt5~4{B|m-YYJmOCHUQHOCSU`>l9={VjvU6""-PwN(v0Z7nLc_}gv|6DX3`wOKa`9M;""`=C@3kmeRM^XaA^0?AE!fkikB5K90LcmgC;w!MTyfzp0*lL6DE>s*fo%I$D*1>$EBcdGR|gU4=kLGHE|OjqO7*;B1?NC+|S*D#3a""Z0H!*F0LG@b;s|ho@K~Tw7*rrA`0$&R^$1}T5XX)53PH@9P+l}fW+l}*y5V$)=XO-eXGsyU`-#WO-Wmm{C""sHPj=3O""FbY%RdyNKImi~b6k!T|i8E_13)C}ayW!^4Z6HZPhz5sq_>T{~{))6U>sf3KYjF*n5OHUi=-Xlpb?#R>CGi4l3j>j9QF6G94YoW{i9f0bVCRU@2CB2#`Yl37AOp*2uWoCH95;tQ-sCm8mSd&^O;IpWU#mSr`n^IYeGvPLG#6h7~4!Kyn+RsKDJ;_ZchyToH=JSVM+*QpG;%1udadY#B_QqwF&<|z!>{O_W?""0I}&=zxIzLrlZ6#7mi`H2IjaQb8^V#GT&3;_zOD@Uev4dKQ~owoT0s&=J(XLG`9ZDd?g)t&;4Y1YO7}B=K;P3NcTYU{|x7=={Yb|""pK_v;BDTH!rlz8@Q=v_a!_*k{^cU_+41D7IQ;7shVr#^IH(N}4U~{^%""(qLEOyjOkp>@2&?&OygQLq^Wz2af(lNw2`~$gE$X`1i`R%6)g4wb-w{-Y?ZW&d*ab4R%c=-tR{BKdPb?""9Q8tu8EvXv>JDx3xIPTe7c>LzEtgR#Stk!(e16|A$oYmZlvAh=81x=gAnbGx5Vb&g1>Iix%R^xV""pi&^W6-Uf~vwDd|e=Y`(opu@l#b~R7OK5C#(+JK8KN9nFf)tgoYGWA%nHju~0HxsV""XaC#4O1QNjHEA?*ys-n}bjh*G&2cfQPTgRIn5hSIwVu-CDs;NI4Pz5?+73$XA@bw1b#Yf4Tvx5_GRyy*Qp!=$6l4O6ha0dmU}5O^""RRBo~2t@`6;~PfT6=g3bt?{l>4LZqsRLSD!UH^7^aUdB8UYAVD0IV7sto+va{E#jo_6SPJ!}I;`B6>|HpwN@uwiWUtX3Q@CKW{O*""O)r}sqN!V?9~BMZztY5+HHFyCKMAx9%S)NTy{oj6zQI4T-7Qv;pgOT2%lwr9hk}o&zzuiy!|`sz50E%sGUdJPp|zlyrZpFM&&#Bv8)h~!cvE%A(lrXnU1^V>&>@5gBeKKjGHmc4y?8txzIR^b1v&f$""Zfa&9x<3IYZj>VAD}wSF_(3Vib`#~D+y+|fcM?D)e&Jju-OCp8nfh9ry0=ly@|(P3xV%6m1Lhxd+SA<;z|$r)^GPfD42`$mYO""^zfEK@Mi!*`u^mNO#rw^{$~7V0jh_M3`Wog;jhH-GP$EHfmz2SIR|x*BcO?J_uR|inUTXPJbM>QE*2vyZGg=ga8O`$evemCnEo#`""J8>6Bdo3DLK^VCnaPNi~q|Ig`%||Cu6A1X=GwhO1ZMs^@EAkMUz@O<-eVh1*ZGM|0==};cID`oc01acj7*3s^yUSNFfXv~da4dI`""=)=4jtP+c6zk><)oTS*k6t{1JVQrMGCk^jBCuft=hVt^342;r=q6bxagt!;!s+FL?uqT8ai5$(3x;H7r{l~J3%l|s4bGF({h86+w""*F0u06GlV>puIm-oEGvmOeh1>8)j(b1(o<>e!|s|I9?o3v0-Xq#I)xgaM^jKqBRdRr}IVCK!IkU#+g4SbKDF`R%}@Up2m%f4bi0i""77$4wvUlUWnDyVl}QfW>r`mFP9u`#W1%$zFnVN90|!gX|(4LeZ##4agXjF=tVFuj!$wSPlK>?>Di`;sD5)}}R~9CaqA""!X|5T6@+w+*QBEV5qT6yfrx|79%hBuB-#u%=$O09dcx}1r{%rs$tdKPVwa~hpaGWjl-){1^wI(TQz+Zm&1h~A?s^#;onn`9TqAD`""$2f*GeoY9SfOa7$t|+=R3YXiB?t>dk%#""YZt*J2=iugz7UX3mZAw`QI8vr$eUdc1I~Dt0?wW(t4f@(t0#+_zHZ815K#GDuM??NHe|I8$%RMS_L_Y8R$*Bz{TbrO&(>OiP&Mnl""S`8Sc4$cH%)_OlzQB%Ld&ofWEC7^Vi&xKtBQiGLLlmmVscL#OEv3V|AbG5In>PkP6-3K5D9@z;k4KM6T{OUQ;D#xFU%>ekfx%0@Un?m>WT_&fZa-3oGj`&<+aZeQ6?JFb$9e!090a)K1w;d-s2A%LxbzO`m(_|9G3!WNL*RiLPCq^ES^2ATU}lV?T{oWxF?zez=xsyi5Nu""A*l;(BD<6Rj2|@q-m7%wT1Z1eBq}8nW-#yDN)K(7rWvp78;XVTOr=-WAo77aA|I!ryi7)djRi^4t{M(P*3OX`hT;_zwdIM?e$=Ps""c9c+FY%d^@RzRR(wu4s&doh@8k%JpYDda7s#M*F${i8-}4Acf@n7yzGX58?nVLiCE0Sv?)I)WQ$-ml5~5bE$dmW{1kEv+5gIeRM@%dP?R(*^-b5n?8b0#uG3M`i@Cqb4i+XRhsEd2^3f*iuy0SJnUeK>8dyFw=Y1tL+PFf!M2E?MG3fSZGK8d6wNWdFgX0heb*Gdb&LwK2T3(A3>d(Ko@Eg@qqrXjY`!|w#c>S*>Q=62SHjf&&""7Ef4W!!uC&ab`$Idh83cVB}K*Z3qN5+>b|;e9-Lvj`9zT5{dI7S}b&AYL`o=X3?-SbD*`kVd(|<(IyKh7u5R=T4TdLo`LgCw7l(2""+7E0f3_teoKEq#*RxgEQm>t`?`_i5lP&Hr{u3r{T7rO|!bzXZNL<9hBj&q05rWR@e)!SJTa{wiZg24bsXIU{i!C""?pAvJr;LpsEk4eh*eWC-vrbNIxg?Oz_fE&Xr-g&C&@NAF$QF%9pqoY%agi>D{?58w@VQGRfGJ#UG2b1QvT@o@o2Ji)Taq;ARymyl""ZnM8wntUwcJkR`aTChVX)wgbiE^I@@+%JqR;-m0-ibprFW`3P*GWB*I%Aj_0CQpmTmGNiUVj&qYQ|pEVS#E$WTFc^ND2G4>yi&V?""PMJ3sGLKOa|7hVJ)*CzO@m**l+LX(yBlvTL*(^Z;6Qa37BVTW`T5P7TMIBc+mGsarM%UZ{1eXB{fyD7fV;|)j7+`c5qyU""BQ7nPJK_92*hs5)@!5D$Kb<9d3_U=Vr8>-%E)*i}BQZCC3=(Y86Zr*t3Zv#-Jx?nPSqCEEe+#BO0FNU?6m7RI8YpP6M|?bH5{*H*""eOO3YM|U0!_mF{MLF?IDo2gC+P=yB=&sHtMxgUq7JhBBQPMa3^52?X#@g~1pP3eae1QdV#m>KX(&YMXPTL^$UzM70|m7q0%`2Q?V""<{7hRXjHpV!3}bLGqV*D#g8_ls03Ebzn${9P4k!_B_+?Qu2@TkD=XtvvPC
J4zDrkRz+jjWTAt7p2_Hj{m&kGIPv|b>w*;JSMsC%T""67VsC$S(_9jSt;u^zlK)3_or+t0BB>Q+9p{*4~YVM*Jim|d4>Y76-aRH""HyK5VS|a--P@LR3DrBh%0v$ou*)-Gyp~2Oe>jDoV^Zez6K57@sswg~gG8waI)YdhLShK)C1wdOO4|ervo*@<&BC==9r<7A~wciDU""AQ3}of6&dh!XMV*X!J(%8kUUo7}9Dt+sX>TFd#2UUWC%Hy!ZJ@8y{Zk>UtBd4ieXNF`wvyQl;{?*NyY@ya=uuy{jurFG6#)=@(0q""-^^8+VQyfSt}U}Cx^X1UP1K_Dvu0I=&i_RS""&FGgywG%`^!Mm&;g|HVEwjOZsIz+hd1$9>B%gEFl-ijEp2dfi=#mVYabqd=s_-bHfBz2vR848h>NlbFbAu+wUzG&lIhZ?IdRL%jN""9myCj{r9E){rpWkwa(^9Zq}pOJW#uu@_7c)bVv%tM8g=k!zME+N1-dr7RROa&?TDU5Jd5r)pyn#6*$lT9GQ_>M48NOPDZp<{`c)fD)`N+7g9d9POITzsZW$^$>O6lda*a46=vjY3YE6Os_Ko_H4u""PB+d>!8~*!la|^565vA11=e+)JNN-$W>9B6VN?bnx%D&4$7f?+)QQ1l*|exfu~yUsFQZZPjsyOEnKxp)a1p@z3O4K=V$PSj*9y>_""0YvC#ohTKw!sNAf30wq_Phu%6jNU_Fnh?2v|1>hot3_QzsrbkD#Ax9;S+lk+OA)x7%5T09!PRjYx1)2@d<5+LTC-*B8hL!gCxXF{_?MzEa>0I?#oKjiHSEc0>8-E)xHltER?DRA4@i6oR@`g?-m6*-S{&$&#IK}wR*Q+jDSTX9{X_4GA`#ye""!paFtf-_Vm;FpB5R>%VxYjPi$mN!O@e=U5COw}""g2?!pV>-Ju+K%`q4cghVU602^47!fv8)t`CJn)UygZ5Y;-fNZ7<2C0fbXT`hVKBNMCNG5jLd==H8Y7q&#vV}kqPP$+aD~Zs*~Hqb""HuGa~blJLX$Prk|%-{-x+_dZ$XiO`OMhbP|RO!;Q;YQ>O$TGpl%mvw4Ts+*54q7AhNN0?@E3<#ZJ1c{@W%@K!`w}#T5oS""RN?AgM)WFF_q7Z5WZg6b_vcO@9gXwM>qRLK15QSw{oXE7Zmov>oi(v2GuO@Vg7L=#(FaFb$gdCj!1""*4ZnRno;NutmEqIZud9a;5OEObKp2@MCGOA#6=u%iivuJ^^FCgr?Tr_M{phavEtq{*z@z}SiJZ@wSk+~;~g8jBR-`9gydu0=iEVK""&XV$~Oz>VOu0_BoJKJbwv%Ei84~`o#pS{?P?F@UV{ePF-c^>i$jNT#F@}zJ0!xGYBKQs`<-Jmoe{|lSGC2t)psWkS9m2%alcBBbt""Fd#a7WRX&dDeUceTGpLR-Bkl{_2eJ<8QUcJe`VQOvZX_8RT)m$WWl0SZ-*nB_lc=EXFg>b`SQ;z7ICJ?^FZK~IM`9pq!0aCMZpLXvX""C8WayVD#vcpoZP--A2#r&MS#m(Z{UdW+ox8N2qgr1UHYXl9B_4Gmj&~bmV{>L)zvP`HB#""ynzRo&_&yKGVjhTp|B-6o!VMt#6LUxw#K$Z*QwR@4uio_Via$Rjd2*+I1g_W$BaO~s1%q-135wI`bOe&m$E6XlK2*Fs*0T=kF!bn6RGK(eM%J7Nid>J6vtA)GDu0dU&LS0YE#ozS!0TaD53V9g5tli""t*L^o)vb~l!oT;$O|u0)!qp?^-DEO)N@^Q%;ftKJ)p)8G3T)Mv%Qt&!}gP+1MOnRPy7am>{Gd}t_dT7B61>%UxFt-1W*Bs;;qLT9nm0!MUACC@W2m*""@p~#4i7S|0FzOIL6QCbUn5)S1yRQLrUwO?@3*)8#$hA+Yky5+YjF_+K""o5S4ln?wr=H&QDIv&p#90U~9k0V!lOIfV-YTYoMoa;_l@thWHC8K|h&y>3zx)nt""sOyx5dSlJ|2mYqJFW_8+)%HL_zCpNp;WvKv5EC74U2hf{6(*BLOt-*e3%__I3rs(L35WGqcGbpw<1VCck$GBRR<&V%cy1usfFpBl""FJZrF#_mDkxS-co(sLkNIA!*Jy?)wT+3dY+exDD1y^j_dVAktMo1$W6X+p9K2Jj`&-arE(8b_3KbB@n*pj*Yu;a?p94#6VB4Bm+zJ+W$UEX693oc%$yqQaHAEB8H2+Fv?UC6~lXlF>UYp""*c5yJo1b37CBFwejFsT^;){7shwY)iX$TTvGOkURJWQ~+R`$0MU5#lO41gGth=ps!a4W#;axKy@kgNDX43U4iSe`mo+m*_vF>-_f""@wDaeaZ+uRde*GtF$-tKD!kYB=W4R4P&|t{0z@{U98VdxmFvvS{aeJ`III|bZZ2>&Qc`xN#q=65hI`99T5RH~a;MCb64m%nk%LRB""iB|4b5ZcjUew_Q2K*?xeXM600TLI6Zq(_3`-MW6rwx7Okf7eG!r{mY9OS}Pc_7V7sheM{WTVI&}{;pV-(TI77jua?wy(>$)kT7m@EjFpGoine5lh^Dq)vFQTEPX""*!#@B5H}Ed`7@__6wqs0mTNDd8XJVOlBk{4O@=_&mJ)#vhC5>aJLa9_GvinXc3GoFo)DcZ*ug-L?M5+a{SiX58PY~Cgwe7TYLf1d""vgDi<98izug;|-)!k%bN8XFSnsS!pc^6JG;V%zj^qp$Yx_B*Y^7""oc9d9(7Q)D88m<$7VwCJ_zifks(neSBXL#M!e};US1tiRey#=pv4P;Amh8YO!!x!fCOdcpyre*JBV9@E=GEK@eB%=aZh}8Z#1yc)pz0r9bw%@}`VM!;6UbtwegSrsk%C62wPn+G=~9&(xP)*5INPTo#>6T8_cFj^OMnY)fekxGvk*f7_Ps""YZz)#G3B0Xnj)-{xx^cz3$Fw8sq`ED=XfE>-?Kz72NZxbc3>f|m@r0yPaLgp4?h=a4cNOGj=l5GNPB0@%mrepp2MWhOu8C@c`5R8""0w67*pd&(Z=s8a039&$sjaC2p?9>}o##|3rYtW@~6Lw8dko@4x)l61Re-%ZBt?)<9>-w%>7nyK#wMFpK1)kNMUo3tD?;DV!N!M_P""iaMRE=}4q|v!t}sWTVH0OVkX==eG;Vu52*+W*<30$bi(8lKIfe#{VGExr47Pfpx>ar>s#G-Zc8g(Bg*hNXa7*EtnX%x_ybIcU&DH""P9k%-W^lOGXU{u4V86LE5nb$>0n}6&0L#jP0-kVw%G(_c@ls4gpDUyW(^ak4nw%uqa6+^fL6%mZJyc~#-""Hy~Fb+M1@6oWaJx?tvjXO;mE-UW!6Pnx=KCWPKY|Ry{m|AK=7HoTu>%D=S!4q762yPl3R#g_{E!j4VT!7&1t&5v@h7f?nXDA6e8|""l3EwZE~r|zk|j;oz?mXI(q2-|J%xW^@wU=TRs@mxb~Aj?crfWf-+ez&%?Hm32zuX2diR?P=9E`)1%=pzqeZLmt4#qrJsR47pDW*F&lSKR9tJ7t*>obpDt+K!bMYv_MkdZ$_8^7tg<-A~y=k^UCNE`zmk(p3vNC$R(jmkd4dXqID=Y%V9Rm""Kn(#ND*52^@
3_fIe)cizSy#q5ZCkuaN^ktfm6_O%&Y&GEDEcp?{r^<=`qs}L""@QpG}XR@tQ7~mYtZg9G-uDws336+^jtit3purZF#31?eh3KjV<9grb)=2s;Pm{*RZnReloVf!iC~OyrDyNP;v=e""x~;_z!%yLZ7nRu^2jZ>)6Mm~q^xywkh<#KYjBJG&j&iA!&2?ZMbyHiP$FqwV3@gi#V=QBsULj&ovm(ztD5Usg5M0~x""KG2{lwfFH2GT&3@d;t~58Bw(hU|o-)JPIUFaKds4HK8Uptm4Ja+&Eg7KB?1$8_5%zS4vP1X>WY+iMi+X5t6rzq`G|7!80INP?|a$pi#eAijwTOyx3i)`Z8Gxjk8WBaq7UuWMW*GN`u^q93M@S!aI_iS3KcPkfNR1kdmAsibwCAwMgT!igIOnU!3Jo;KK;_A$L}~R{;y7zJ{v#sA4>5y!M;swaxE>""ezFRrONt2sZ9<3ki+5%mwX|5(xBLio&i(^?Yek0$L_QvIFE`Xy+mm%DPZ&^<)@-n#BeLG`f8q<_SgQ@v4p#)`V3Cr-(SfpQ#{#4(""o{~{%MEf8`;0TJ?ou!_GOw@&Pp8}T>Li#Gs$Kvb9a(4?ZzndDAQ|v#rx97-QYP&UeWW)3g_wwRa76T8;8`+8U6@2)n@n0lF@QSB<""TVk80o~d@Lu|rs2QAo(Ccr<(D!M?5@`O?dDzF7Gt@TZb#>cbu;>N{R~w4W=9OwuBOpLE}riDOS%hrof{YKkg^l1@#oF#!V6lGpFW""7=azVWjYfaUBm0x6KzY#;e@<|6h#Y^ZrJ4X@#6^hIO4T7^3pO(OoY24)Njkv+%%WW^J&PhV?%{7C(rB+gAshvAh@PU68wUsA;WN|""KCgKgBU1^==EPtXMwbK6sV%^gJ5^VGBChL%zKwX}HeSYRJQ1Tq~8gblvMVPN^7Ve9m8q2Tx3|mlVa6qHV8(x})cmlDF^@=R7O-ur-L>rUWh7dS|B~H<9F5T7L""j1`ao(BU@lwK)gy?>e)ED+UWfnci%aadw12WvZzC#fMuq_jm@Sw{lZ$^w+$&e)wjL63uuGVfS#TPf+_cQzI*eBI1r1Zpi4Llua3pYNy4vtvlFU~!%wJ&a7L!R1NC00000^F}))5Jte=00GIB0icTo%e;*xvBYQl0ssI200dcD"))) \ No newline at end of file +exec(L.decompress(B.b85decode("{Wp48S^xk9=GL@E0stWa8~^|S5YJf5;YMsq-dzARn@VT6Qap3bu0*kgCR~YUqB0W9R)iarr*QtEZpesGY3>~CZRiK|6Dwut$nH#N""!RYqQnA}G^`ZsFO;ar92)Xt#3E3Ki5S1}OfSx<=$c<4=h|J{kt$27^CQ01M+lVgZ0tGgX0&I*V@{U&JgYc0U!(4F-btCy*+qzv6D""p~UW!y~6{U*}y$E@2-R}vd?t*s#fnDO{!j>OImt34A(d+9n>hnnvzmd((_D1Cghg~(bQ$Yj)>!Y%{*o9ex8FWa#U)!OI!!5Prl^?bnBX2V(=(Bvc+CvGo!S{LhLn7pSsR!}@""U=OBW0)h6IYneQ1{|$<&k9TS^qGQpb-;#vEPAl%11UF)?6mtC8c04XzR$+h2=j84E2|i`pOEt$uyM`lGs*ejIF-}^SvSRZK$ePh1""`gt+?%1r#=OVy3pW`{ofc)6PhQQRP|_h56zl+sQ(le1eJ^&&qZxdGb15""aOb^-R1ouqi-H1_w|H;g(()bKz_!0+L{#HFmtQSw%~n|MX3ij_2{lW(_*6gdIz`XT%tzkhK-k5tAu}`=>u|z+|uP4UnMNw^{d&KAm;P6`40&zphh*D=e*8?KGZuo~y*`y#Wg}r(PV}?J$Oae&vv%2(eb|cSn}oTUc1&B%M^B@xlvn10$Ol5{Rc(||g*`0AO8zk14;L7H$""%dzpg>)#q=gM_#uqc-I?vb;GXT^_H=R{GSK>=VuvqSe%#z|2c2YlT&*kjZ2;Kl!jcm$Q-fVo|ict&Ja4ywb`i8dvCk+h?3OCDUXX""&zMsrr>Pa>PMm^gxGuU~bBkga*GhQ$y8PeK(vV5>_veh8WQRi`fT)Py)OBaHdkXS7+2I5uiFuTjpn{|fPzQ@LVpRws+Pag!JmkiczrmnPtnMu9kKEb_KkhHbOR{Wt""8%3T$xbVm9!-A4Y!!lg(FQ%bN;xc;)J}M}+SWEaAW4y-yw9R0Q^AK@NSl^zI0|UPwqq?6`rga!kw{>S5i#M*_l%=PQ;>8$L|2FxThROI+2n1yh%V%z*SB!)&{G^cNY-1sEn*QxI>3(Mq#-79*&HQx{l(5_G&smxlM^P9n^Rw+snn;2J5K""+nT$`x$au4QlQ;UK9VKSJ^ZsCG*f{fZeEUIiN6m{@t+`ew5)LOdk-VMqWlfcBU7?nxLFgfE?x""b6Az{EuCO+S7#4-pxaPAl""(DONH5{t3k?rrWsL@mpCUHPbMPya8454xhhOJwf3(m|sY7eZT;5;fOW8%FFIEqW8&2?)-71^h*Ga9=nxiacg5nW%jY56jIq4Am!#Af%Ad7b;fSu8t""$W0?^M^3JvwEwjL7wvBuvv-sQt;=;pA19#M1g=W3%g!++Ta&Pj)}6%(rmy(w1Ya{^5TZ6je7V3lo*%vz3G$IqGSc1Yk<@ueD-^lP""3>l=~HATuJsnUd7lSv1hzR!o!Cqg*vpqo?H!H#F#qeqDew8JLvu8?Ch?xbP*m(b{S0`8GH-3tOgQT7;SaPNV8fm#zUwNRTcE@G@W""+jk9^I`mCeVsjz}lIoTv50U$zCf#Ui4uX%VDo5bs-vLZo_x0r`2|tIB*U}cLyXQ7=78niZCTHmt1$P#TUdM`P1r%XC!Mfom@NI{l{DZy1;89(d%L$BtDZFH1NaTdSZcWfMclca35;Qg22-674m;""pFNb#C7XsQB*kC4dE$?&KNO4Qs%Q~HXIRE%U)*?lT{Xj~p!?ka)_Bno=GJBn5M!Zfb>?Eqj}7jL?ct;;1tV1%i%E>in)Cx!WeUF(""<(m7t3ADHh@DgeP#pc}Eeu2?bCz3w^A+AI1(xU2Hqbbe!&Fxs|v14@I;SI5er>_S++RxO|ur%K6xYT(6JaB}s)QlHbv""Z)6Bw+AX&rZptjNW62D6QMGF>PTDD-=uVIFXKGEg`>~}zmCKBwRaC>?s@~W?#0buJ@U`4@l&pe61;H%?;6Q6%<%})hiqO2AWA>0vO6s%^XlaGwMC@uO5^K8dY|}XfV|*YTjvMf""^@u$1$x6v0XJ|vHxw#$4ip?Awj`#mhj}NwB`myCgJoy;q`7y#}mkqeC7SdtKKRAU3ZU7;wp5hW0zyiDIOE*VW{9UXqA$mN)lVbH6dhCF*>HaJ>tNE)-@TssUii}Y*CD{I4_G%`eqw#b0%fyt|;*i@e0cz=;@""5>?^i@R4Ck?NQyTO+R4Xm)XAqsI4O(u8t}56&!qffK?uZhzjPSzgY^w!+3jjeYG>Jrnj2Oq6ns@F|=M$2v^JTkd7g@C""M!d$)Fs^r4e|;n)DLzd}`BG1LKaonOUt%k9>o8xDr?#e`$lIbNs==FD>W{W(n$jL{;Xfkmp6DWf$U8&
ND9y^j5|rhhi>Y%r&i(GMQU4%}Sol)2_maQ+b*(nZ9)@$bk0G0y@(GL?Zhk%>MHXV*)(OXx#s!*Q-aja3QbuK9f4p-F93>F4xs+ZTRG_@*Ox""(Tw(zGwQzlTvONhN*6L?P;rgmivbLx%mE>!68)lZxUPaWAdP6(JWyEqUi7PAab6M3Q1k^ryv#jSKMXM4ES(rB9@NC*BPnz|UVtBH""OT1m7zs3BNQReVFG!KarK*nL)BE<-R-gTg&c>4Q9)JpEMVLJi_B29>mSHPm3Poc3o!gPjs3fQD~Eu5S|t__-z`AU~AE1e51)*Qd+""9NIi=^&Y;!iv7d4C|(QPjI*`7UwTO9^AF%{n&lD1e?|U2LGb3iPE#Yx$6Ag=B%rFXeL7+Q@+q~fQ*Y)`$xl60Lj=1(7?+AC=hJcZ""@RGyvUOdLQthGpLaepQv3s9dRu6T?qdth=vex;#!oY$8*TdkE?Uu7t2#vd;^mymC%9J>0dE5o$UC64)*Ei`a1133I+!RbtUOzGxC""qbWeT_V({}2o-ybQK&}WdVq#kS~^8%y{}E37av^OP$FkBD6!ce;Z$An-pLWFte`bj<>ht_#=44GDGkpjgBFq""sw{3b8S$6%m%t^D{{!U)vVB`@I_-LD*e?UZSDSY-=Lc`F_;H*<{abP;#Xbxn9_ZU|I6xD3!BL@;&z+;R{f~Tj{s79jIya_@!9lK)""E_L91IdRHk*q~95#av3k#D+u)w`Faav0hNP49GqaB9sF>UhHHD79l}_UhQ_`m`kW;FFhwT0Ehc?BwG;A=mU41M1aQdfo3M7hGh+`""5}rM%v6k%dO)|?EAx{2m`=>6e^^0IhV@DwRO?tlUX}Dg{%03Y91kKm5xz7pN$~u+Gh%WHERhhZh8bs*U(@8YGNY(X7TK$p}1M}7M""nnEHAjmc5&SPb|Y{H=M;i%j#$%E1KFN3)VpK}@TO0CG=M|ESDG@F$#dZ}(T3fL1KW*%faXoRz|B!QY8>I3Ig_piE8tde@!NznZmt}mI6&dgwM|imnVB}CvISopw7}Tay-)hPY+i^vD$Z#QxqqdN&E$)IDNDktDpZvIYdz#n)+IdI5~qDE}*k3""@AI;u*gs7t~w8u#|;;lux4I3gUbjsk-eP1izoM%ZT-{xD$Zs}hsSXP+!vQQNmb""I(>|JMux5bIBc&3Qb3@Nde7+2+`qI8(v>G2cD?e1;y8!OC&jm?8tc{d*2#T%ssy0>L@FPMBV4Y>5$Z)G;jia1o2mh+oLG9ic^jG(""RhGR8xl`bNxG|C;B}Xu}?IN4ZyU{tP)1+W;V22hyDQB2ew|jj4JVR=$N0u7-Ft1x6swd~Ix*kUnOi0N6`}FZ_#<2^Mb&!g+Cx@5m""J|)on`J6Bl9H-}Ai?pzaKyjl*DJZol(#;UxvxsQZd2wS{PV12yOh4!oxTvaCVs~Lu0N;)q*)*K4MB9Hwgkljs`2WJ9bu6@@#Ru7L""lLoMYhOlrxk8!{?H_1{(J$`oA4g3_dZ{8x%L5oeEXP1j-vYM6J5p0qJ_3b;y@LqbYhr&cNb&V(&(^-V;g2BjwzZuhR+YIi@I}j{|""{{6m(ZV}zhA!6f_?mneg?f3B)3cFy<#G8t`kCWFSuTKAn*6PBKJ7xc8)#mc6?(&pjH^EA+K>acmtW>L+C1~jUQl=|Ir-1%Uh""Wx+>RA&D;+_+I!CH~Z=7+Lc0EQ`FB*JJpW9f)V51xQe>qN>aQ1GQi8lMkWEW7oo0r_lKghcZRT{U;$pW>Uz-$(rsk}ECNY^xTf_@&Xarggx>Z-chgo){>Q^hR90&A7n&=IOuYGsk^OC%v+zt92><5%sfO50YIvDz5iz+nlVES5UJRHh3""izM1`V|P{46rvhf6ulkm%$TD*rG&IV2q{m%ANj%4XX+cn?mI)>%gnhUDdsHCM_{dLW|bxYjDz6~0EnV<{(mSalL)%Ut;IzJZ3WY}""w)rlF__5w=IhUbNEF7EzTa6T%?KCGJue2FSq#U~C0dx*oo<$}Ey9*Ql>#@p~*x7N}nc0aHnQ&4xvK^kfiNiP(5E`}x@Cf@KEWjak""SXg^3>C1DtGI0O5OL_k}dtz4rA*4""$dTDxPbI!7@A8uW5p>;(7z>`n%hu03VIr9&yjLDcXH;HQBVp*uY8FG!BhOUJI3FdvYo9CT?Stx~0i)fKW%_qEud}0Z-sIl_fvjS@rd&g?NdbfLcExMOrtfOyIuIF$T^mj8w`DhZvFc&LaX-5`Kw+y|bqAy*YnemRGNjZi@s$FybRly?pa9;Aeu`3LIEzMDwYT>ZFe""9tjiJ!x;qmDy=J2*%4aKBnbo7`bJAFHDwWx!}#}%QYS{LuO{n`_=""ZD?h63?8uAiW;9H*)hW_+Ilm^suj&ZwnUI(!)^Dt^EmAxI&Y-yWcvTNncLLY219$TDY!q$B9YOp*-*mX0Y}8I5?^ji*GS6iuN0xl""J>8z3t)fC1%#GXf{>8;U`1^}Ld2>ex?>c8o(P%_4)ZZ(mygu^9Mbn0zo(5ddn2|0HP!>pPm*Fjc9I!tkWr8?Kdq0J<2Qi_gF^-a(""zJO|PMrVdmA&oVjBcX}(-qgxsDKOHp!nj<6GloG$vy@6I%TwC|i}J3?u6*""4%j<=4GOS!oc&(;BLv#{V(wIZqv0Q`Ws768VsMm_#;6Moe9*jj?VhNa-Cz#q_A*)$ME_LEl`Kz1E%&okK>SsZ""27@=zQQ05i>y!@us0q&~CHeDX`0{BqF=1YYYhZv&yKePMM8)YUBL&)^A>*&8fqTjkIX&Aplo=0?_)m83zk>I}7&laW$o^i3B8Qa=""^NW${ct&LBg(ZMb(MUOt6;d9(O5vGamk8hGj$(_;Igj<50hL!Ym8_Y4fMhOyn1N@GKuZc7Q*C}$rE$D^2p+bx8`3^e>N0W%WF45u0#VA>77V(Kg5S%Ux+oOC7Zw*x""$|}OU2i=(6;1CUKOXc^+v^=w1$|UqoVBzhFtH_9|P}L85)bKd>bBF>Y&4W~Bs(pn#""Jvo0MY|kfwr4Et8JqQnsMPmlM9>?*wnu$QLSI+IevQ&^ZOQNp>P=EiPiQvUOMNTSdwgT*7`t^{9<9__R8@DSxK>(itZ^AxCMUEaR""#36)Wi>?{Hh;pUWt5LmVwzZgS$;^9F!BL!>-54h9*f!+M%s~bCfEFpN}m`87&Jk@aV""NZ6=WeM+6{qq3KnhTu4hlk0e(F{|1In`S$I2ubp?*F65_WiyHCsa@w^X#%wx37ZozheVMwA;S(WnaWcfsdK$(%CLS~-)35~e8Hi$""QqTbOnTzfKs`V37jL(adsZ&6IWxzGfEp|lwY_8UqalWuF2KpV#E96)_X(cDT9znz-y<-@YF{ZrA-pjK1^sP=Bhth$OAVghd;D{5?""qWo6k+-Mv`@FNOwK{%vHU!Emw>{D""?`X$4N9%=TsP3_?@7MQO1t7et4W;G{<7*Aiocs}*oXdAHeN39mL#@u>(AxVB;tJC+OvXDX9qp7vd#Xq7r;VdQDHWG4o6i`""fQyx^Btr@#n9FU|0vYTFr~fWB;{wEpX5nAD*put3h7Oq&Q9UEmO17S&YRsb7(>a0`}-JLHnW>z(b=#hV
{Syy~G7o-_mO)KsqeA_phT|3zbW*S57)iG^(93T3mUrGlJzSSaw+SK1|Bh""CPaU&z*|WoT;3yR$6P@#rfqWHT4yRxFVXt|>f`G|u)weN*iy6m_^pdVS2NR)yu1pu==lCteGQoi*aLA5ovmrPp;4|uMW;~=83(SU""^Bp;8qopNfuON$iuUxLV@Acd3EHs)iv)^9&3>d!&h(Clqt^#_Q&v?$C>_>^5C0Yncg@1R^xN!fmNULYZ(TP""E){|m6vd?6?~UUJMz^#qdhUCtOH)B%#4m~8X4JUkwo@`hL$y&Z@)w85)`7uGnr)VM0NunjxzHr>F%402zb_(-C1+4_GMHQp2080r""<2hy3Kvgq{2KynY*Q=v0Yw5#^OGim^Ye@-&wb^7aTtJ5vlhH4i{{cz%b`_;>5p5VjD;Qr-c>zoHdzTfGf0g6<2VgSX-ggZ;F~<`c""T(MNO*eol|7;XeHfqf%ggpY)3l#6uGx5<@$d);;E5j_CWNl;YzFScqZ9?QTfES|;X!tKIL;;z-^eIOs$AS*{a1wX9nseXj)#3_8t""4No_$F>-?omyw5%*GKFzt6h_5)hoky>Yx}3==Z-Ch|OzxG4!UWz;Jr$kX8|1O&ArM?BN5dwDNtHO3jk~{pxy@1+G|rqrbrl=|s6`""nO9R&e%Ts?S_m1lV2J}%a`lr3Yo9}zmu1QL~<#bi@%=Ea=bC!7?KXF&ALQ&g(RIimsICP@{HzVgL0b&wl}BjO9-2qih<""xYDM`3nFkj&(RLHx6VfMEL^va(dA3b3@o$_kWN%x6}R>p>UlV0(3%oiuD)I|e-aq8Bj5PgS*i)+lfuq#4LDyXt+{W;N>4=KE+*5i""H#`x%hXK_V2mOP(to3!g<>8omNH_utG4`w_tWS?w65f~l%%I_wxXX|vB6p~2lW*hmeaxIj2m78Ha0Elq##6HbE1Mg}6nDx!2`m0X~9Wy(23KUa*uCf{qezNJ=""Yp%ZI+mXOOuT!X)Q%gAD-UCZIZ;7cqcErv780eq8R_Y}Bzx03Xay^|-9d$G)BQ4ZWAB&V5Y=y;;5!uGqFluNNaTW-MVE&GuTKe>z""8nLuv1T^^7|$=GcI74jiHhYz~i3Gb@49LoNUz+L!V^?SyX&iEf0kuN6lSuv6`Wx$#eBc&H1`AC3xkF#}(D@""0~=VMt69UuF)cp?3DOxy02!(B-iVfeh&{D}HOLSqQxAfdnj?<8`aT7Kk7WRY5&x4_J+~O)92eY*7nz--;Y*LDQt`o)e!R""=sT(hcw}otijhfwXf~;aXT%IQxd6wf$TW%Tg07{lwXv-I#hLHSCp8gq;Y>on2kxUfT>N!(o+$rAYGVfsE>#$3CTfqHwtBIzt~xb2""ha28gLBqJb0N$CKo>^2Gx9j4b4c>n4u&E@m3D!F@|KN2o@NTB){Kfx7KkHsm6E-@tVOzQpu<~co_0{Txa>pFkn""7h>eu{#W#-?{lSoE^)V0!>+eq5LQd?a-E|Ixg*ruYkCLMH%7>cT5""UYv+j0Ly7l>LnV1dAL?;0nIH=>d|+&0D3U`u4d<5t6TF#vVK{7cgzai86Epj_4^HDq2(;v0-JRK4w@{K3WPQ!MnR9|N)e""Qs-DIPVBP+0vEfkrT~d_rgEjqSyu6vwXs8C_$I`JJIV4+4#&&4P2cNQhPnm&=+d^Do3Uzz2z2dyW>mq6GM""dm!rIi5kl#8C|*w0UJSK8ba+-F+W&{_CQZAuZ;S*Nw@RWb+%}qvyFz=Ux~9{HIQFD(~<;?RfnOTYmKDiY2nFciN5};x6q={1p;m>""tg!I73ONDrt=O99fw<{b!WkqecLH0D&5bGIk6""iaV9;ALR>Yf_3g}5|W{pFuh=9tu4BFcx@tR6QSCo8lnpzi1EtBkJKPH0u^+&uiv)py}y^9_Fj_8R830LL7HAto4noWc-^_""q7CV9Qq?M-M)h1|DM64T3)EgGrf-*VZ&P~+JehS4DV%}_=>W-$2^Tbujg8J@""Z(v>iQ^+&Llz~c~%wy7QK|wp-lY*MNcW|DFX=7>q@aUmX9wOyxo@XO`wfWG9Jr-vbi=h9y(3vD`3fD!cp&|NO2SKUr@W{^-`RGvc""2jjC7ZuSs!5LtyJ4>D4lXitBpTtiC7QBa9R5C(RmE0r`zy1i}Xb6KSH)Pct?&|oBI!$n!DlwWyK0v!#IyG|Wm2vliE9g;`Y(4~e9""3kdOgIqTb9nM@Kbbnl(AksL-O*=dTU<-jJG#G8E+o^1t*16|H{e<#Kx3<)-*`u9e0ep#W-Vl1sqe)>Urr+|^{|3xY>_~}zRf2H{G""h14Px`m7bb?gZWt92OzcoFI^3j`LQu2(|)EMI#^!gUUqWIY(7{{;xS(_$3D`clFD(m@8ya""nt(&%*&wu%+injC>pO$Vm{U0BAdiR3z(+(BUL-~KRmN_&oPM4x6(|O45qqj~N=|ypHGNs*O8SU*u%TPif@12W%!@6ZWMv%d(g`Gq""(6ASRD?BW0Lgl?SgnRY}Syi(t`K$?eK3Kl2aAl8NBWtS_bWNh;sQ31_G^cu;iX;Ize#b4!VlNWxktkPyCfixZ>&~@xCQ3-eTu@4*cByo769mO*J-Ri+_h=;fSRxdd!!x!wx7vA--i3;(5`ONPDPn1uGR{m^#QL;j$~nKD2FXlVf%S|V;n!D?9JiuyU5Wgw""PXr^;OUQ9$faYYR;^!{89QfMnUrl1`^(-&uZBjNMhaJ8z7{r?h&1aPk*%eZy=DR8YWgsp(^@t+69aV6x_jFqT0aLtbcK=@Q(zB@Z""balM_UU)0OtE)7@?lYVV2IU4^(NO}AsZyfq!2&XD""2%@y)D4$o7Kw`L70=%|>V|QGxQ@ggw)^LZGlt#*xG{ZT8LoZJ-Lw|}^h%WJ8+hFhKPdtK|7Sh3cW>&RKVi}{w7Hj*F9U??jL0y}7""cjT7r1sb+`%TI-HCn@vL13QPl8G?#%X009FGjq7Z)*aa3T|q`fQP;t{`_SqbL8{-UmmAs6!ApYS7$!vKX>%Kgdv02|4nI~Ul|b{vgO3*!wM=%2MokWnHIFocd>R;w}c2e""2f!^ilxcytxYttCYfd{+tR{bZ=Dr^WJPKC)@ClS$ZDqI$V|2&b|h=;cp@Us~b*^""EmhW;?~SR9Lwb@cYXxDjVbi#WWtHe)8A-f@%r=MJ=1*R827%4#g}4?<>W7oquoa>-jK2ppAvtlC-ASq}&`>7Y>aw~?OaoM=jCVhP""9tV!We?q2%L`$I66oOtqNs`sDOp(s5C^9S6#OQM8VeKGN?G7>}@|)wq+HW%vKOR8L%h8Z0&6iyP}@>ww*FuUcVsk4hR`bn~8oX9vUxW6bWS2v>hZe$<#|{tAK1*""vuSMc*Zm(fBGU`)0Rkx@Sa?dzdi>=jc}@ys7P;O*X4WiRq#>yYgBfD2nxkBq1VMX9sz-t7z<@TbQ}a2)JWATQYe|xo9d^x@_C~7M""YqLG6l7jaJh|r6!r1naj&FuQX32crFoIn-L<4@D)h;))Zkyp2yHyH}eO@Lvnve$;OB5&KL&{u;+()W0<#o(kvf$-loeHzX%%NaCM""o+o4HK5q}z(?RTk#clDL+QlqE_tL(YOn
4p3;2Fb*$OG8SB7U#WEUIue4wt8iW*mOglwo~9TRxT9D+D%#+z+GU{xyNDc-OGUtNH`3""j)CmhKTtzz*Pww7Ec""!*7Lgq!XqxZ7=(N{U@^@+bvTQ7Xlx*#j{%%{xFqk$6Lm3>F~66-Y4d{IR(VzU(y@8^Skxk6|150Np6du6NjD(*""A)m?DfZoa{4%9yOy{HT2?W>jS5>c*u=BPPUXhpLZ<5S!J7n2MC)moesTd9*)dqcOy|`9cL%pc""=Svz=D@v9p(&1dCYu!rFD+qJw{rm`BD7~4hHb=?N""q*53S>NENM1j}elZvpqy0M;czwQlDZgi9wt7}LF&fmQQ!zb7Q51MVG>y(G<*$F$2NwUoSf283l_0wM<;)4#hg}9DRT`~3Ug^c?Wa+lwEHvifrUoACxwAlKG(E24""k8Ce((|jHC$rSn%KEtb27vPes-NsaC^1#Q@m_Wo_;U5)AEchK5#UrOz8QEkyAo_o^;$;$z=}Lpl`q^Jvl?#uFi{6JXiq}OV$Dax5s>I}K3{A5N9#~z@|-mHX~5?H!D+K~ckuM@0&J-d`+4uHR;Y2kiT6BXiR""A0D0e)`p0=0n6kTRYM#dlRz1ku!sd_rNWs;Z)U3uFK83spQ8=f+$~jl&9_j{;uuQI2bS$(%kBg|8cg!~28ND-cy`Z1V8{u(Ug$d^""dgmXHOYXSKg&r^u5N`Xu=-ISn;ls{$gn3rZ!EK@XDKy9y!P(|nuB6s|Ll1B|(#-1E_x}2!cq4e|``N@;;^ln=SwQ@jCg!`i(Xw&m""xr7OW*>U9YN3q%Wk3z<>7&nM7UN;UJ`Av>$M;v7hLiOmik(T9KO;(@R+Q5t$iUz`j`bg%ssJk""^HiaTbNiJ(E&RuqvU2f{F09klX<<&vMI9Iw@1ykvT5kILHI76)@!@)C~EeM#s""0odg+Gomw_W8m5_%QGRr*QsS^!Iq7fst*?#LgRbPJx{O;Xj6=JTKL;t7}El)^Gdr^`Q4;l!DA&6C#D|e_`c7U(>X^4Q9a`>Mu^K_PvlV@?I0~ut_5T=T8R1B(Z7P$?I)Cc0!L7h^E>$8sbXgq}%1y9q!qnJ)+r#_iS""FUVwbb)L9LGb-rDdtPB{lGn?@DpCnH*J)s8H`ZMuv14i$vol}M9`YbV=)hxuVNRUYY)$n|vk=5iLH$-}d+8%^r_(a=iM-uLppI~r""6|!$F7U%YQ!a7M|Rqv_xWa@)kqO6K})+aq$(&G02)F_bXE98Fwy$*?~0=C4!yqrlKLChh<#bNu%TmW3f^-P9wXe+tC9q=%PENRQ<""6U^U{R(z#{QT6AhW)G~+EqSTh>_MNkqM6DE4&j~6GnLey;~Bs?QAKDyQDpX#4|$U&KIRDOy!VQX|*G3Z#fgPBJnc(6O@lNcQ5-(*Y5Ot)Sd^5g2e3(1SyOhs)B_}#CJbsrj`""&fc9Gf|_8;Ve@+guF#D^jrpWE2El7#9Y!~kFe}|rrh(3OOQ{wlTlWswX0ao0>En>FGUaa7Vv+*vio6G4`Z3G~*T>LXD*OMam5}4D""wqu{BrOw-s$e+@NbwNQ90z#{rR7oX2Wf5{kz2LsSLiFS3aZ!Fxxc4QTg+y*@(n9HAr5aYSJO-qQiO=nq2E#5${FE<=`2_muOX3OF""Fg9gyq8UINs04fQ9&~s&Eu8h}iQ}#0&4E?NXBh;xrt*3HL$4b)vozmUoky47jCg%|?!q;E$Osf4c!7OGKDoG6*v181&%R+%S~wH>""fxg{f`b%FHN7_C0is>n?a?6Ke8R*JO=FvxnZW!pRjVlsaAfzLEj8(;lY%NC@cfO_mONuN1L9xz1?CrRehT*?S$O3h1NmVwbR>@H*""ZM`|GOV6jh5mBf!B&|Wurmde~-)IIhngiR)J`cRlVl}A0^LvYY>&Cjhq-quJ3=~9@|9HDLHuB&AV$uQ^vsrn2cQhYu}s~fsvmU""31Wv8C&RU=HW1FA41D9vEGyK9Sfn)hl?LtaweFn~jd9DKetHlU;{EUTd_m>sVXn~%%|e(_^scqoyWbX(*~9116iRL%T=iv`afM(9""B*q-L3!+Ihk5H{>O%I!QBc^4ii-5!plh7@h05)T4q`w>$|Jlhz2)xMwHhBiknGcH=U1=i#jYA*AI18Ui6|lw)RDZ)2T=IpHd-1cZ""tP8;q-D2DzJK*YL6B;QzO@jE03@D6DW$u<75b$;T?GiV3fox!sClHrE9O->#rIs4iPg*)^E~RQN`n6r>*3MKH1o$TtM`F5rboP<="">|LL#LzrvLW?&c4$P{5kcHIeg?dQ{1hF$B3m=(Ttrz0+0R=$i=R;gcVYp5{1j2|I`0u@qnmsc8}a6;B~Njw""{uXJ#39?&}TQI&e-726xd4bd>dVahE7yUO{2Q(0N#Mns8)aWMP*4X%T1;GV$fjY0WQ@dZjkK@OZ{R&>^te~flHsu4iU_A>pB@Eiw7ri8;1$D6ID>}%$*AWCteY{Bi5Z7#Gl%ColYI@kI_et4?{H7X)Hqdr`;w_?K4U}1$9jwY""ZI=(k9BYO5AJqIdor;z@*0nyrzQkW>ZqcNDAvEdY4(7Wq9B~7+Q#FybVrehEAG$t*X>YRuk@B{n$IJ+7qhr~SLvAR$lrI0vEFkO?""yYmE;DN-oRzQ8olwbbkvbAkzGuY+skaB3R&MCOm5Pb>LsYIKirpD03k3s;Khe&imDLRqMtF^OzkSyA=B)uBk2`stR8)$du-HDRKy"">GKrWtK8Sd>v7TTuyGiS2Fh5EXuz>9RFN7MQL$+^9QaVCBiNdD;V-Qs=*Epq}GB==EO1#e>dM0IKaat5@)X?&hGO3>QFRZ2)`""?A~LBjP;L~!xN1J3<00QF~vFO=Tr%?n{w(%&xI)8%Ye5gxF&7OZ`yEyOrTz!_2BP#UtTzbmWZkAifJYz74$a*;Ho;gVj$r+JzoSsO+DZz""Zg6D^_qiEK0|v5LLo0l)-;I~es~s#hhFxi+nklM^Z@sEB^|q`y_xk)Wr+NB5$73}y;BzQ<~f>G(yUu=0{4JrBJd(WvjTdG$<~X@thE*30S>~yj;=Mvn~(IU?O}qgc*tejPrXjn#;own7_QAKd)t$e""AQNlI5a6r0u8QXP<`570jU!bVsXCXH*Y)?eks^-^dvkBWfj`KP)C}WmZywgc?)C14c-AG9bsUHB01d_~j<1O%{jKHH$H<~ht%%a?""zg##95a6i3s^tLN`<*p6N}Q-2pzbEDWYtlaunUXG&?1TH9og&IL2PdiE_QY&T-Y!e-y@j>9h7NK@;kk>)W-(t&1t5kjbncx#WLBj""Mx7>xZqeB!vcGsL7743tjUVU@X$R^8Z(}RxVIhz3cX)#bIbG!1@hwb""&TtQQON2HEjjyPXwM6r6_qNxS`rmeC`rwwQi3E{YRM4a=%;9=C7EG$8h)PXMLKXR2zq$1_s*TiBFHDkgds4jUSpU8k(XnDtWrsK4""F(X`Z=zljd6nKAHj*F)v0V!LgugvS!AW^zDPQeJU""*)|V;cmRoiW^%(tS=95uPvY)=-t8+b0bK4nMl+7MAd^lEIIHmnDYrCg+{$(n1Son9ug72Hf>j4X^Y@~$G8zgB!UX~kjgZq56{E<
_""YFjAo(5fcjbK)kRETXHwIaV5&Z;0RzADHkZ1PcYXofF{Yj*??Q>""JZQawy}s<{CE4w4>>x^FbDMp3>X&|wq9XWO-)%e4~)@t6?qNc4&_;;u91<>xl+|HIP=n%`WeR>)pBWW""L&cq+o$xIb%MT{5dr_!CH*;2AT7O7S_D=>Q1r-lU8ArorZ8gOe4Xo=Qr-L)MLcTtwgw!h}+rWzvBUTA9{p~uOr-KU{q$`)|=JO5^""W;XyvK|U`vm!Y$=-8#o$HZa|3OaPClw&8KSrN(3%)%uVL*QY`I0AThSLP3QM9hMT1X6S+OaIR7?xb~nfh1Qj+@*ry^b~t0leQ_YG""FU=dl>F`q6AA3`HFcrZw=7*Qlu=$PyG9LkNp%o!E!5}k!_O|h*}}*$#F=hw235AZC5VJ-d%UGc7`}Rsyj6L-YKg;XQ0$fa34lL""iC;6&g($sh0C7*3oH1ULT7Y_9vDx5x|E;rsz89cxE~M6;Kqc@gLQCX5k%g0~dLky5HfsVYH3MxRs5""`IY0t+6tT>sj7F`0ZlE(a$K0F4HBcCirR{17gh}ZuuSsDn""G;W7PeHcoiq)ntbx}?5?Z3B+>a+y>|1r-ORsnCZ2Y4;`0FPr!)Y?){P19Mt7K5QWt%g5PzgR3plj>#s#*A&KG1i#;KxBX5YSW7x6"")47V`c=I2R%p+R6bCY|!F#A2Xo4w^sBkWA<9RMsB@7y7w{brbdAXw)TV~>yZe*fd^TL((eqj6BMCYpTSmY6_^D87&B`rB64!ueN3""6Ac){1D}wk9FKNb2e1R9qN;Z9y0vEh59m^%n@j*l+A3r}lN(SR5Z+Blvq}T-7i!oe#LpHulu`JdRPJ@$5C>BPYT&lH8ZI%8Le^%|""h=Tp7{qPq;QR6*E(+L<$V#uEvw>^<(n)R_pXyCVhS<;rzvOlCz373a|AvartXEkV%U*x@(L`rg2*?5`?SuZWNP3c76C=l""^?$Uu`R+l3dhY{z2x|)jCy=XT(NJq>=yG@fox(H?*J4OYc3W}HsNuaa@%L;yegcLNT+0HcnR}dBNsL^kz0bG1{1;-bbaiqlg7S7B3qUzjt2hvUTf??3Cg2EfVaw5EX!G""3)U$SqlA9t?ZD6gty=JCPd5TgW6n0j3^kN?ZVt_9iFIGMppOcwf6U*2KJA4QdmQ&2R(q}%87kZ#ZO@OyjgY=+?`t+r5UvdW1;@Wz""aNsg!R-ANyagfrK7GsQsB#zM2+GVWD+XM51Lo0QBw-KY*ANL6bn|DY>5l4Y&dEw6RneDohSeHQyYB`wp%}M3z3V}d;Vk;bxpGO7<""x;matTTqNY>CgRSmd$i?=UU1;vAghGp4pV#BwxjzLv7!7UPBW?2icX(Wh=aFiKMomp@;G}u^@3p*unCcd2|e;gHSCRC-TE0=7$YZ""?fylSm?;d^&pR=&jrt-IDg>i8kC^G)#nx&T%1#_bMyr{{eNC_lW>CN6(JYqKt;tWEES5DR6{j&UE)DVgj+i;6Pn;8}_yDXget<0%""nal0*)jcMAh8Rc7u?M!apU~U_&D1j&(!}Qrx-x?=;)kEbpC_iwo>63I^2Fz-Q_&m{D51@V{Fu;MlEZ#i3SFay2!pw*xZD$b