diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/README.md b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/README.md new file mode 100644 index 0000000000..994dc028ad --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/README.md @@ -0,0 +1,67 @@ +# SP8192 + Depth Recurrence + Polar Express NS + Phased LoRA TTT + +## Summary + +11-layer GPT with SP8192 tokenizer, MLP 4x, depth recurrence (layers 3-5 looped once), parallel residuals, Polar Express Newton-Schulz optimizer, and phased LoRA test-time training. + +**val_bpb: 1.09085** (3-seed mean, 8xH100, quantized sliding window) + +## Key Techniques + +- **SP8192 tokenizer**: 8x larger vocabulary vs SP1024 baseline +- **Depth recurrence**: Layers 3-5 run twice (14 effective passes from 11 unique layers), activated at 45% of training +- **Polar Express NS**: Per-iteration minimax-optimal Newton-Schulz coefficients for Muon optimizer +- **Parallel residuals**: Attention and MLP computed from same input in layers 7+ +- **MuonEq-R optimizer**: Row-normalized Muon with momentum 0.95 +- **SDClip GPTQ**: Hessian-weighted clip ranges for int6 quantization + int8 embeddings +- **SWA**: Stochastic Weight Averaging (every step during warmdown) +- **Half-batch training**: 393K tokens/batch for more gradient steps +- **Brotli compression**: Better compression ratio than LZMA for model weights +- **Phased LoRA TTT**: Score-first test-time training with batched LoRA adaptation + +## Architecture + +| Parameter | Value | +|-----------|-------| +| Layers | 11 (14 effective with depth recurrence) | +| Dimension | 512 | +| Heads | 8 (4 KV heads, GQA) | +| MLP multiplier | 4.0x | +| Activation | LeakyReLU(0.5)^2 | +| Vocab size | 8192 (SentencePiece) | +| Quantization | int6 (weights) + int8 (embeddings) | +| Compression | Brotli | + +## Training Configuration + +| Parameter | Value | +|-----------|-------| 
+| Optimizer | Muon (matrix) + AdamW (scalars, embeddings) | +| Matrix LR | 0.028 | +| Muon WD | 0.095 | +| Embed WD | 0.085 | +| Warmdown | 72% of training | +| SWA | Every step, start at scale < 0.12 | +| MIN_LR | 0.10 | +| Batch tokens | 393,216 | +| Max wallclock | 600s | + +## Reproduction + +```bash +# On 8xH100: +cd /workspace/parameter-golf +bash records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/run_final_submission.sh --nproc 8 + +# On 4xA100 (local testing, TTT will be slow): +bash records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/run_final_submission.sh --nproc 4 +``` + +## Attribution + +- SP8192 + GPTQ embeddings + SDClip: @clarkkev (PR #1394) +- Depth recurrence: @dexhunter (PR #1331, #1437) +- Parallel residuals: @Robby955 (PR #1412), @msisovic (PR #1204) +- Legal TTT framework: @abaybektursun (PR #549), @dexhunter (PR #1413) +- Polar Express NS: custom implementation (arxiv 2505.16932) +- Phased LoRA TTT: @dexhunter (PR #1626), @romeerp (PR #1610) diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/decompress_train_gpt.py b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/decompress_train_gpt.py new file mode 100755 index 0000000000..6458256f8e --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/decompress_train_gpt.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +"""Decompress train_gpt.py (LZMA-compressed) into train_gpt_readable.py for human reference.""" +import lzma, base64, sys +from pathlib import Path + +script_dir = Path(__file__).parent +src = script_dir / "train_gpt.py" +dst = script_dir / "train_gpt_readable.py" + +content = src.read_text() +start = content.index('b85decode("') + len('b85decode("') +end = content.index('")', start) +decompressed = lzma.decompress( + base64.b85decode(content[start:end]), + format=lzma.FORMAT_RAW, + filters=[{"id": 
lzma.FILTER_LZMA2}], +) +dst.write_bytes(decompressed) +print(f"Decompressed {len(decompressed)} bytes -> {dst}") diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/eval_only.py b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/eval_only.py new file mode 100644 index 0000000000..a63e84acb7 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/eval_only.py @@ -0,0 +1,193 @@ +""" +Standalone eval-only script for sliding window + phased LoRA TTT evaluation. + +Usage: + torchrun --standalone --nproc_per_node=4 eval_only.py + +Loads a quantized model from final_model.int6.ptz and runs: + 1. Sliding window eval (eval_val_sliding) + 2. Phased LoRA TTT eval (eval_val_ttt_phased) if TTT_ENABLED=1 + +All classes/functions are imported from train_gpt_readable.py via exec(). +No training code is executed. +""" + +import math, os, random, sys, time +import numpy as np +import torch +import torch.distributed as dist + +# Force EVAL_ONLY so train_and_eval() would skip training if called directly +os.environ.setdefault("EVAL_ONLY", "1") + +# ---- Load all definitions from train_gpt_readable.py via exec ---- +_script_dir = os.path.dirname(os.path.abspath(__file__)) +_train_script = os.path.join(_script_dir, "train_gpt_readable.py") +_ns = {"__name__": "_train_gpt_readable_imported", "__file__": _train_script} +with open(_train_script, "r", encoding="utf-8") as _f: + exec(compile(_f.read(), _train_script, "exec"), _ns) + +# Pull out everything we need +Hyperparameters = _ns["Hyperparameters"] +ValidationData = _ns["ValidationData"] +GPT = _ns["GPT"] +deserialize = _ns["deserialize"] +eval_val = _ns["eval_val"] +eval_val_sliding = _ns["eval_val_sliding"] +eval_val_ttt_phased = _ns["eval_val_ttt_phased"] +BatchedTTTLoRA = _ns["BatchedTTTLoRA"] +BatchedLinearLoRA = _ns["BatchedLinearLoRA"] +set_logging_hparams = _ns["set_logging_hparams"] +log = _ns["log"] 
+timed_eval = _ns["timed_eval"] +restore_fp32_params = _ns["restore_fp32_params"] + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError("bad world_size") + if 8 % world_size != 0: + raise ValueError("world_size must divide 8") + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp, + ) + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + # ---- Hyperparameters ---- + h = Hyperparameters() + set_logging_hparams(h) + + if h.is_main_process: + os.makedirs("logs", exist_ok=True) + log("=" * 60) + log("eval_only.py — standalone evaluation") + log("=" * 60) + log("Hyperparameters:") + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}") + log("=" * 60) + + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + # ---- Validation data ---- + val_data = ValidationData(h, device) + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + # ---- Deserialize quantized model ---- + if distributed: + dist.barrier() + eval_model = deserialize(h, device) + + # Enable looping if the model has depth recurrence indices + if len(eval_model.encoder_indices) != eval_model.num_encoder_layers: + eval_model.looping_active = True + log(f"looping_active=True 
encoder:{eval_model.encoder_indices} decoder:{eval_model.decoder_indices}") + + # ---- Standard quantized eval (non-sliding) ---- + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + timed_eval("quantized", eval_val, h, device, val_data, compiled_model) + + # ---- Sliding window eval ---- + if h.sliding_window_enabled: + timed_eval("quantized_sliding_window", eval_val_sliding, h, device, val_data, eval_model) + + # ---- Phased LoRA TTT eval ---- + if h.ttt_enabled and h.ttt_lora_rank > 0: + ttt_model = deserialize(h, device) + if len(ttt_model.encoder_indices) != ttt_model.num_encoder_layers: + ttt_model.looping_active = True + + # Warm up rotary caches for TTT eval seq len + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + # Build compiled forward_ttt function + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + _fwd_ttt_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora): + nonlocal _fwd_ttt_compiled_inner + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + + fwd_ttt_compiled = _fwd_ttt + + # Compile warmup + log("ttt_lora:warming up compile") + t_warmup = time.perf_counter() + for bsz_w in [h.ttt_batch_size]: + wl = BatchedTTTLoRA( + bsz_w, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), eps=1e-10, + weight_decay=h.ttt_weight_decay, fused=True, + ) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz_w, ctx_len), device=device, dtype=torch.int64) + yw = 
torch.randint(0, h.vocab_size, (bsz_w, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, :min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + log(f"ttt_lora:compile warmup done ({time.perf_counter() - t_warmup:.1f}s)") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + + # Run TTT eval + log("beginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled, + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + f"quantized_ttt_phased val_loss:{ttt_val_loss:.8f} " + f"val_bpb:{ttt_val_bpb:.8f} eval_time:{ttt_eval_elapsed * 1e3:.0f}ms" + ) + + log("eval_only.py done") + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/generate_submission_logs.py b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/generate_submission_logs.py new file mode 100644 index 0000000000..868b76d4b2 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/generate_submission_logs.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +""" +Run train_gpt.py with 3 different seeds and save training logs. 
+ +Usage: + # On 8xH100: + python generate_submission_logs.py --nproc 8 + + # On 4xA100: + python generate_submission_logs.py --nproc 4 + + # Custom seeds: + python generate_submission_logs.py --nproc 8 --seeds 42 314 999 + + # Dry run (print commands without executing): + python generate_submission_logs.py --nproc 8 --dry-run + +Output: + logs/seed_42.log, logs/seed_314.log, logs/seed_999.log + logs/summary.json (parsed results from all seeds) +""" + +import argparse +import json +import os +import re +import subprocess +import sys +from pathlib import Path + + +SCRIPT_DIR = Path(__file__).parent.resolve() +TRAIN_SCRIPT = SCRIPT_DIR / "train_gpt.py" +LOGS_DIR = SCRIPT_DIR / "logs" + + +def parse_log(log_path: str) -> dict: + """Parse a training log file and extract key metrics.""" + result = {} + with open(log_path, "r") as f: + text = f.read() + + m = re.search(r"seed:\s*(\d+)", text) + if m: + result["seed"] = int(m.group(1)) + + m = re.search(r"stopping_early.*step:\s*(\d+)/", text) + if m: + result["training_steps"] = int(m.group(1)) + + m = re.search(r"train_batch_tokens:\s*(\d+)", text) + if m: + result["train_batch_tokens"] = int(m.group(1)) + + m = re.search(r"world_size:\s*(\d+)", text) + if m: + result["world_size"] = int(m.group(1)) + + m = re.search(r"model_params:(\d+)", text) + if m: + result["model_params"] = int(m.group(1)) + + m = re.search(r"peak memory allocated:\s*(\d+)\s*MiB", text) + if m: + result["peak_memory_mib"] = int(m.group(1)) + + m = re.search(r"swa:applying SWA weights \((\d+) checkpoints\)", text) + if m: + result["swa_checkpoints"] = int(m.group(1)) + + m = re.search(r"pre-quantization post-ema val_loss:([\d.]+) val_bpb:([\d.]+)", text) + if m: + result["pre_quant_val_bpb"] = float(m.group(2)) + + m = re.search(r"Code size:\s*(\d+)\s*bytes", text) + if m: + result["code_bytes"] = int(m.group(1)) + + m = re.search(r"Serialized model quantized\+\w+:\s*(\d+)\s*bytes", text) + if m: + result["model_bytes"] = int(m.group(1)) + + m = 
re.search(r"Total submission size quantized\+\w+:\s*(\d+)\s*bytes", text) + if m: + result["artifact_bytes"] = int(m.group(1)) + + m = re.search(r"^quantized val_loss:([\d.]+) val_bpb:([\d.]+)", text, re.MULTILINE) + if m: + result["post_gptq_val_bpb"] = float(m.group(2)) + + m = re.search(r"quantized_sliding_window val_loss:([\d.]+) val_bpb:([\d.]+)", text) + if m: + result["sliding_val_bpb"] = float(m.group(2)) + + # tok/s from last logged training step + tok_matches = re.findall(r"tok/s:\s*(\d+)", text) + if tok_matches: + result["tok_per_sec"] = int(tok_matches[-1]) + + return result + + +def run_seed(seed: int, nproc: int, data_dir: str, extra_env: dict) -> str: + """Run training for a single seed. Returns path to log file.""" + LOGS_DIR.mkdir(parents=True, exist_ok=True) + log_path = LOGS_DIR / f"seed_{seed}.log" + + env = os.environ.copy() + env["SEED"] = str(seed) + env["DATA_DIR"] = data_dir + env.update(extra_env) + + cmd = [ + sys.executable, "-m", "torch.distributed.run", + "--standalone", + f"--nproc_per_node={nproc}", + str(TRAIN_SCRIPT), + ] + + print(f"\n{'='*70}") + print(f"Running seed={seed} with {nproc} GPUs") + print(f"Log: {log_path}") + print(f"Command: {' '.join(cmd)}") + print(f"{'='*70}\n") + + with open(log_path, "w") as log_file: + proc = subprocess.run( + cmd, + env=env, + stdout=log_file, + stderr=subprocess.STDOUT, + cwd=str(SCRIPT_DIR), + ) + + if proc.returncode != 0: + print(f"WARNING: seed={seed} exited with code {proc.returncode}") + else: + print(f"seed={seed} completed successfully") + + return str(log_path) + + +def main(): + parser = argparse.ArgumentParser(description="Generate 3-seed submission logs") + parser.add_argument("--nproc", type=int, required=True, + help="Number of GPUs (e.g., 4 for A100, 8 for H100)") + parser.add_argument("--seeds", type=int, nargs="+", default=[42, 314, 999], + help="Seeds to run (default: 42 314 999)") + parser.add_argument("--data-dir", type=str, default="./data/", + help="Data directory 
(default: ./data/)") + parser.add_argument("--dry-run", action="store_true", + help="Print commands without executing") + parser.add_argument("--env", type=str, default="", + help="Extra env vars as KEY=VAL,KEY2=VAL2") + args = parser.parse_args() + + extra_env = {} + if args.env: + for pair in args.env.split(","): + if "=" in pair: + k, v = pair.split("=", 1) + extra_env[k.strip()] = v.strip() + + if args.dry_run: + for seed in args.seeds: + env_str = " ".join(f"{k}={v}" for k, v in extra_env.items()) + print(f"DATA_DIR={args.data_dir} SEED={seed} {env_str} " + f"torchrun --standalone --nproc_per_node={args.nproc} {TRAIN_SCRIPT}") + return + + # Run each seed + results = {} + for seed in args.seeds: + log_path = run_seed(seed, args.nproc, args.data_dir, extra_env) + parsed = parse_log(log_path) + results[str(seed)] = parsed + if "sliding_val_bpb" in parsed: + print(f" -> sliding_val_bpb = {parsed['sliding_val_bpb']:.4f}") + if "artifact_bytes" in parsed: + fits = parsed["artifact_bytes"] < 16_000_000 + print(f" -> artifact = {parsed['artifact_bytes']:,} bytes ({'FITS' if fits else 'OVER!'})") + + # Compute summary statistics + sliding_bpbs = [r["sliding_val_bpb"] for r in results.values() if "sliding_val_bpb" in r] + summary = { + "seeds": args.seeds, + "nproc": args.nproc, + "seed_results": results, + } + if sliding_bpbs: + mean_bpb = sum(sliding_bpbs) / len(sliding_bpbs) + std_bpb = (sum((x - mean_bpb) ** 2 for x in sliding_bpbs) / len(sliding_bpbs)) ** 0.5 + summary["mean_sliding_val_bpb"] = round(mean_bpb, 6) + summary["std_sliding_val_bpb"] = round(std_bpb, 6) + + summary_path = LOGS_DIR / "summary.json" + with open(summary_path, "w") as f: + json.dump(summary, f, indent=2) + + # Print final summary + print(f"\n{'='*70}") + print("SUMMARY") + print(f"{'='*70}") + for seed_str, r in results.items(): + sliding = r.get("sliding_val_bpb", "N/A") + post_gptq = r.get("post_gptq_val_bpb", "N/A") + pre_quant = r.get("pre_quant_val_bpb", "N/A") + size = 
r.get("artifact_bytes", "N/A") + steps = r.get("training_steps", "N/A") + print(f" seed={seed_str}: sliding={sliding}, post_gptq={post_gptq}, " + f"pre_quant={pre_quant}, size={size}, steps={steps}") + + if sliding_bpbs: + print(f"\n Mean sliding_val_bpb: {summary['mean_sliding_val_bpb']:.6f}") + print(f" Std sliding_val_bpb: {summary['std_sliding_val_bpb']:.6f}") + + print(f"\nLogs saved to: {LOGS_DIR}/") + print(f"Summary saved to: {summary_path}") + print(f"\nUpdate submission.json and README.md with these results.") + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_314.log b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_314.log new file mode 100644 index 0000000000..3ecf181ca0 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_314.log @@ -0,0 +1,245 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + batch_schedule_enabled: False + beta1: 0.9 + beta2: 0.95 + byte_weighted_ce: False + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.45 + eval_seq_len: 2048 + eval_stride: 64 + eval_temperature: 1.0 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 0.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 0.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.175 + is_main_process: True + iterations: 20000 + kd_alpha: 0.5 + kd_enabled: False + kd_logits_dir: + kd_temperature: 2.0 + kd_top_k: 32 + kd_warmup_frac: 0.0 + ln_scale: True + local_rank: 0 + logfile: logs/0edf8df1-7bb0-45ab-8c47-7cd1b2dac7b9.txt + logit_softcap: 30.0 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.028 + max_eval_seconds: 600.0 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + mos_k: 1 + muon_backend_steps: 5 + muon_momentum: 0.95 + muon_momentum_warmup_start: 0.85 + muon_momentum_warmup_steps: 500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + parallel_residual_start: 7 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: 0edf8df1-7bb0-45ab-8c47-7cd1b2dac7b9 + scalar_lr: 0.02 + scale_tuning_batches: 8 + scale_tuning_enabled: False + scale_tuning_lr: 0.001 + scale_tuning_steps: 20 + seed: 314 + skip_gates_enabled: True + 
sliding_window_enabled: True + sparsity_start_frac: 0.0 + swa_enabled: True + swa_every: 1 + swa_start_frac: 0.12 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 393216 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_alpha: 144.0 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_lr: 0.02 + ttt_mlp_lora: True + ttt_momentum: 0.9 + ttt_ns_steps: 0 + ttt_o_lora: True + ttt_optimizer: adam + ttt_reset_per_chunk: False + ttt_swa: False + ttt_warm_start_a: True + ttt_weight_decay: 1.0 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +val_tokens: 40540160 +train_shards: 128 +model_params:35943512 +gptq:reserving 0s, effective=600000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:on +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0096 val_bpb: 3.4879 +1/20000 train_loss: 9.0115 train_time: 0.0m tok/s: 6000007 +2/20000 train_loss: 12.3845 train_time: 0.0m tok/s: 5873520 +3/20000 train_loss: 11.5210 train_time: 0.0m tok/s: 6001402 +4/20000 train_loss: 9.7915 train_time: 0.0m tok/s: 6055040 +5/20000 train_loss: 8.5813 train_time: 0.0m tok/s: 6089621 +500/20000 train_loss: 3.4181 train_time: 0.5m tok/s: 6289888 
+1000/20000 train_loss: 3.3235 train_time: 1.0m tok/s: 6292457 +1500/20000 train_loss: 3.2367 train_time: 1.6m tok/s: 6292115 +2000/20000 train_loss: 3.2793 train_time: 2.1m tok/s: 6293809 +2500/20000 train_loss: 3.2708 train_time: 2.6m tok/s: 6296979 +3000/20000 train_loss: 3.1554 train_time: 3.1m tok/s: 6299759 +3500/20000 train_loss: 3.1875 train_time: 3.6m tok/s: 6301729 +4000/20000 train_loss: 3.1841 train_time: 4.2m tok/s: 6304002 +4000/20000 val_loss: 3.1824 val_bpb: 1.2320 +layer_loop:enabled step:4330 frac:0.450 encoder:[0, 1, 2, 3, 4, 5, 3] decoder:[4, 5, 6, 7, 8, 9, 10] +4500/20000 train_loss: 3.1242 train_time: 4.7m tok/s: 6234468 +5000/20000 train_loss: 3.1561 train_time: 5.4m tok/s: 6105692 +5500/20000 train_loss: 3.0456 train_time: 6.0m tok/s: 6002621 +6000/20000 train_loss: 3.0549 train_time: 6.6m tok/s: 5920883 +6500/20000 train_loss: 3.0621 train_time: 7.3m tok/s: 5853523 +7000/20000 train_loss: 3.0969 train_time: 7.9m tok/s: 5791784 +7500/20000 train_loss: 3.1122 train_time: 8.6m tok/s: 5744162 +8000/20000 train_loss: 2.9257 train_time: 9.2m tok/s: 5703005 +8000/20000 val_loss: 2.8982 val_bpb: 1.1220 +8500/20000 train_loss: 2.8348 train_time: 9.8m tok/s: 5665109 +8631/20000 val_loss: 2.8786 val_bpb: 1.1144 +stopping_early: wallclock_cap train_time: 600047ms step: 8631/20000 +peak memory allocated: 18411 MiB reserved: 19408 MiB +swa:applying SWA weights (675 checkpoints) +pre-quantization post-ema val_loss:2.83269945 val_bpb:1.09662623 eval_time:5582ms +Serialized model: 135426937 bytes +Code size: 32872 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 67 Hessians in 6.1s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +prune:over by 2030 bytes, selective pruning +prune:18095773 cand, zeroing 24360 +prune:zeroed 24360, now 15964858 bytes +Serialized model quantized+brotli: 15964858 bytes +Total submission size quantized+brotli: 15997730 bytes +eval:budget 600s +quantized val_loss:2.85970099 val_bpb:1.10707937 eval_time:7418ms +quantized_sliding_window val_loss:2.81649225 val_bpb:1.09035192 eval_time:86685ms +ttt_lora:warming up compile +ttt_lora:compile warmup done (155.6s) +ttt_lora_alpha: 144.0 +ttt_warm_start_a: True +beginning TTT eval timer +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 num_phases:1 boundaries:[2000] +W0501 06:23:57.825000 12487 torch/distributed/elastic/agent/server/api.py:739] Received 15 death signal, shutting down workers +W0501 06:23:57.828000 12487 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 12555 closing signal SIGTERM +W0501 06:23:57.830000 12487 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 12556 closing signal SIGTERM +W0501 06:23:57.832000 12487 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 12557 closing signal SIGTERM +W0501 06:23:57.834000 12487 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 12558 closing signal SIGTERM +W0501 06:23:57.836000 12487 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 12559 closing signal SIGTERM +W0501 06:23:57.837000 12487 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 12560 closing signal SIGTERM +W0501 06:23:57.839000 12487 torch/distributed/elastic/multiprocessing/api.py:1010] Sending 
process 12561 closing signal SIGTERM +W0501 06:23:57.842000 12487 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 12562 closing signal SIGTERM +Traceback (most recent call last): + File "", line 198, in _run_module_as_main + File "", line 88, in _run_code + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 995, in + main() + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 991, in main + run(args) + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 982, in run + elastic_launch( + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 170, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 308, in launch_agent + result = agent.run() + ^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py", line 134, in wrapper + result = f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 731, in run + result = self._invoke_run(role) + ^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 908, in _invoke_run + time.sleep(monitor_interval) + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 86, in 
_terminate_process_handler + raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) +torch.distributed.elastic.multiprocessing.api.SignalException: Process 12487 got signal: 15 diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_42.log b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_42.log new file mode 100644 index 0000000000..fa704fed72 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_42.log @@ -0,0 +1,237 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + batch_schedule_enabled: False + beta1: 0.9 + beta2: 0.95 + byte_weighted_ce: False + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.45 + eval_seq_len: 2048 + eval_stride: 64 + eval_temperature: 1.0 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 0.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 0.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.175 + is_main_process: True + iterations: 20000 + kd_alpha: 0.5 + kd_enabled: False + kd_logits_dir: + kd_temperature: 2.0 + kd_top_k: 32 + kd_warmup_frac: 0.0 + ln_scale: True + local_rank: 0 + logfile: logs/c387af6a-37e9-46ce-96ab-7d962e511d1a.txt + logit_softcap: 
30.0 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.028 + max_eval_seconds: 600.0 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + mos_k: 1 + muon_backend_steps: 5 + muon_momentum: 0.95 + muon_momentum_warmup_start: 0.85 + muon_momentum_warmup_steps: 500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + parallel_residual_start: 7 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: c387af6a-37e9-46ce-96ab-7d962e511d1a + scalar_lr: 0.02 + scale_tuning_batches: 8 + scale_tuning_enabled: False + scale_tuning_lr: 0.001 + scale_tuning_steps: 20 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + sparsity_start_frac: 0.0 + swa_enabled: True + swa_every: 1 + swa_start_frac: 0.12 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 393216 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_alpha: 144.0 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_lr: 0.02 + ttt_mlp_lora: True + ttt_momentum: 0.9 + ttt_ns_steps: 0 + ttt_o_lora: True + ttt_optimizer: adam + ttt_reset_per_chunk: False + ttt_swa: False + ttt_warm_start_a: True + ttt_weight_decay: 1.0 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +val_tokens: 
40540160 +train_shards: 128 +model_params:35943512 +gptq:reserving 0s, effective=600000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:on +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0090 val_bpb: 3.4877 +1/20000 train_loss: 9.0126 train_time: 0.0m tok/s: 6623905 +2/20000 train_loss: 12.3729 train_time: 0.0m tok/s: 6561103 +3/20000 train_loss: 11.4835 train_time: 0.0m tok/s: 6488964 +4/20000 train_loss: 9.7734 train_time: 0.0m tok/s: 6449721 +5/20000 train_loss: 8.5891 train_time: 0.0m tok/s: 6425375 +500/20000 train_loss: 3.4259 train_time: 0.5m tok/s: 6300857 +1000/20000 train_loss: 3.3242 train_time: 1.0m tok/s: 6302031 +1500/20000 train_loss: 3.2298 train_time: 1.6m tok/s: 6298994 +2000/20000 train_loss: 3.2786 train_time: 2.1m tok/s: 6298743 +2500/20000 train_loss: 3.2658 train_time: 2.6m tok/s: 6300972 +3000/20000 train_loss: 3.1592 train_time: 3.1m tok/s: 6303294 +3500/20000 train_loss: 3.1920 train_time: 3.6m tok/s: 6305846 +4000/20000 train_loss: 3.1895 train_time: 4.2m tok/s: 6305765 +4000/20000 val_loss: 3.1842 val_bpb: 1.2327 +layer_loop:enabled step:4331 frac:0.450 encoder:[0, 1, 2, 3, 4, 5, 3] decoder:[4, 5, 6, 7, 8, 9, 10] +4500/20000 train_loss: 3.1278 train_time: 4.7m tok/s: 6254266 +5000/20000 train_loss: 3.1582 train_time: 5.4m tok/s: 6113468 +5500/20000 train_loss: 3.0527 train_time: 6.0m tok/s: 5976418 +6000/20000 train_loss: 3.0583 train_time: 6.7m tok/s: 5889893 +6500/20000 train_loss: 3.0650 train_time: 7.3m tok/s: 5818860 +7000/20000 train_loss: 3.0981 train_time: 8.0m tok/s: 5765645 +7500/20000 train_loss: 3.1112 train_time: 8.6m tok/s: 5714411 +8000/20000 train_loss: 2.9164 train_time: 9.2m tok/s: 5675497 +8000/20000 val_loss: 2.8948 val_bpb: 1.1207 
+8500/20000 train_loss: 2.8299 train_time: 9.9m tok/s: 5640097 +8597/20000 val_loss: 2.8796 val_bpb: 1.1148 +stopping_early: wallclock_cap train_time: 600047ms step: 8597/20000 +peak memory allocated: 18417 MiB reserved: 19440 MiB +swa:applying SWA weights (676 checkpoints) +pre-quantization post-ema val_loss:2.83333412 val_bpb:1.09687193 eval_time:5336ms +Serialized model: 135426937 bytes +Code size: 32872 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 6.1s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15966812 bytes +Total submission size quantized+brotli: 15999684 bytes +eval:budget 600s +quantized val_loss:2.85680436 val_bpb:1.10595800 eval_time:22126ms +quantized_sliding_window val_loss:2.81389729 val_bpb:1.08934733 eval_time:118209ms +ttt_lora:warming up compile +W0501 06:03:55.301000 2413 torch/distributed/elastic/agent/server/api.py:739] Received 15 death signal, shutting down workers +W0501 06:03:55.303000 2413 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2481 closing signal SIGTERM +W0501 06:03:55.304000 2413 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2482 closing signal SIGTERM +W0501 06:03:55.305000 2413 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2483 closing signal SIGTERM +W0501 06:03:55.306000 2413 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2484 closing signal SIGTERM +W0501 06:03:55.307000 2413 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2485 closing signal SIGTERM +W0501 06:03:55.310000 2413 torch/distributed/elastic/multiprocessing/api.py:1010] Sending 
process 2486 closing signal SIGTERM +W0501 06:03:55.312000 2413 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2487 closing signal SIGTERM +W0501 06:03:55.320000 2413 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 2488 closing signal SIGTERM +Traceback (most recent call last): + File "", line 198, in _run_module_as_main + File "", line 88, in _run_code + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 995, in + main() + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 991, in main + run(args) + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 982, in run + elastic_launch( + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 170, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 308, in launch_agent + result = agent.run() + ^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py", line 134, in wrapper + result = f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 731, in run + result = self._invoke_run(role) + ^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 908, in _invoke_run + time.sleep(monitor_interval) + File 
"/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 86, in _terminate_process_handler + raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) +torch.distributed.elastic.multiprocessing.api.SignalException: Process 2413 got signal: 15 diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_999.log b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_999.log new file mode 100644 index 0000000000..21e9e1da1e --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_999.log @@ -0,0 +1,285 @@ + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + batch_schedule_enabled: False + beta1: 0.9 + beta2: 0.95 + byte_weighted_ce: False + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.45 + eval_seq_len: 2048 + eval_stride: 64 + eval_temperature: 1.0 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 0.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 0.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.175 + is_main_process: True + iterations: 20000 + kd_alpha: 0.5 + kd_enabled: False + kd_logits_dir: + kd_temperature: 2.0 + kd_top_k: 32 + 
kd_warmup_frac: 0.0 + ln_scale: True + local_rank: 0 + logfile: logs/1b893ce6-90dd-4aa8-be7d-ea140ac1857e.txt + logit_softcap: 30.0 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.028 + max_eval_seconds: 600.0 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + mos_k: 1 + muon_backend_steps: 5 + muon_momentum: 0.95 + muon_momentum_warmup_start: 0.85 + muon_momentum_warmup_steps: 500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + parallel_residual_start: 7 + phased_ttt_num_phases: 1 + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: 1b893ce6-90dd-4aa8-be7d-ea140ac1857e + scalar_lr: 0.02 + scale_tuning_batches: 8 + scale_tuning_enabled: False + scale_tuning_lr: 0.001 + scale_tuning_steps: 20 + seed: 999 + skip_gates_enabled: True + sliding_window_enabled: True + sparsity_start_frac: 0.0 + swa_enabled: True + swa_every: 1 + swa_start_frac: 0.12 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 393216 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_alpha: 144.0 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 96 + ttt_lr: 0.02 + ttt_mlp_lora: True + ttt_momentum: 0.9 + ttt_ns_steps: 0 + ttt_o_lora: True + ttt_optimizer: adam + ttt_reset_per_chunk: False + ttt_swa: False + ttt_warm_start_a: True + ttt_weight_decay: 1.0 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + 
val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +val_tokens: 40540160 +train_shards: 128 +model_params:35943512 +gptq:reserving 0s, effective=600000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:on +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0076 val_bpb: 3.4871 +1/20000 train_loss: 9.0098 train_time: 0.0m tok/s: 6528233 +2/20000 train_loss: 12.3444 train_time: 0.0m tok/s: 6541910 +3/20000 train_loss: 11.5189 train_time: 0.0m tok/s: 6489937 +4/20000 train_loss: 9.8334 train_time: 0.0m tok/s: 6428527 +5/20000 train_loss: 8.6307 train_time: 0.0m tok/s: 6405212 +500/20000 train_loss: 3.4286 train_time: 0.5m tok/s: 6296206 +1000/20000 train_loss: 3.3201 train_time: 1.0m tok/s: 6287190 +1500/20000 train_loss: 3.2359 train_time: 1.6m tok/s: 6282937 +2000/20000 train_loss: 3.2830 train_time: 2.1m tok/s: 6282348 +2500/20000 train_loss: 3.2653 train_time: 2.6m tok/s: 6283954 +3000/20000 train_loss: 3.1576 train_time: 3.1m tok/s: 6286834 +3500/20000 train_loss: 3.1876 train_time: 3.6m tok/s: 6289512 +4000/20000 train_loss: 3.1938 train_time: 4.2m tok/s: 6292130 +4000/20000 val_loss: 3.1854 val_bpb: 1.2332 +layer_loop:enabled step:4322 frac:0.450 encoder:[0, 1, 2, 3, 4, 5, 3] decoder:[4, 5, 6, 7, 8, 9, 10] +4500/20000 train_loss: 3.1322 train_time: 4.7m tok/s: 6223672 +5000/20000 train_loss: 3.1558 train_time: 5.4m tok/s: 6095307 +5500/20000 train_loss: 3.0531 train_time: 6.0m tok/s: 5992498 +6000/20000 train_loss: 3.0605 train_time: 6.7m tok/s: 5911094 +6500/20000 train_loss: 3.0667 train_time: 7.3m tok/s: 5844025 +7000/20000 train_loss: 3.1016 train_time: 7.9m tok/s: 5782287 +7500/20000 train_loss: 3.1132 train_time: 
8.6m tok/s: 5735162 +8000/20000 train_loss: 2.9261 train_time: 9.2m tok/s: 5694179 +8000/20000 val_loss: 2.8982 val_bpb: 1.1220 +8500/20000 train_loss: 2.8317 train_time: 9.8m tok/s: 5657064 +8620/20000 val_loss: 2.8800 val_bpb: 1.1149 +stopping_early: wallclock_cap train_time: 600037ms step: 8620/20000 +peak memory allocated: 18411 MiB reserved: 19408 MiB +swa:applying SWA weights (675 checkpoints) +pre-quantization post-ema val_loss:2.83408903 val_bpb:1.09716418 eval_time:5553ms +Serialized model: 135426937 bytes +Code size: 32872 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 6.1s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +prune:over by 1020 bytes, selective pruning +prune:18097590 cand, zeroing 12240 +prune:zeroed 12240, now 15965875 bytes +Serialized model quantized+brotli: 15965875 bytes +Total submission size quantized+brotli: 15998747 bytes +eval:budget 600s +quantized val_loss:2.86641728 val_bpb:1.10967946 eval_time:7381ms +quantized_sliding_window val_loss:2.82296926 val_bpb:1.09285937 eval_time:86613ms +ttt_lora:warming up compile +ttt_lora:compile warmup done (67.3s) +ttt_lora_alpha: 144.0 +ttt_warm_start_a: True +beginning TTT eval timer +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 num_phases:1 boundaries:[2000] +[rank7]:W0501 06:43:53.579000 15949 torch/_dynamo/convert_frame.py:1676] [2/8] torch._dynamo hit config.recompile_limit (8) +[rank7]:W0501 06:43:53.579000 15949 torch/_dynamo/convert_frame.py:1676] [2/8] function: '_fwd_ttt_inner' (:1564) +[rank7]:W0501 06:43:53.579000 15949 torch/_dynamo/convert_frame.py:1676] [2/8] last reason: 2/7: 
ttt_model._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._seq_len_cached == 240 # # :256 in forward (HINT: torch.compile considers integer attributes of the nn.Module to be static. If you are observing recompilation, you might want to make this integer dynamic using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this integer into a tensor.) +[rank7]:W0501 06:43:53.579000 15949 torch/_dynamo/convert_frame.py:1676] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank7]:W0501 06:43:53.579000 15949 torch/_dynamo/convert_frame.py:1676] [2/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html +[rank3]:W0501 06:43:53.805000 15945 torch/_dynamo/convert_frame.py:1676] [2/8] torch._dynamo hit config.recompile_limit (8) +[rank3]:W0501 06:43:53.805000 15945 torch/_dynamo/convert_frame.py:1676] [2/8] function: '_fwd_ttt_inner' (:1564) +[rank3]:W0501 06:43:53.805000 15945 torch/_dynamo/convert_frame.py:1676] [2/8] last reason: 2/7: ttt_model._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._seq_len_cached == 240 # # :256 in forward (HINT: torch.compile considers integer attributes of the nn.Module to be static. If you are observing recompilation, you might want to make this integer dynamic using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this integer into a tensor.) +[rank3]:W0501 06:43:53.805000 15945 torch/_dynamo/convert_frame.py:1676] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". 
+[rank3]:W0501 06:43:53.805000 15945 torch/_dynamo/convert_frame.py:1676] [2/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html +[rank0]:W0501 06:43:54.492000 15942 torch/_dynamo/convert_frame.py:1676] [2/8] torch._dynamo hit config.recompile_limit (8) +[rank0]:W0501 06:43:54.492000 15942 torch/_dynamo/convert_frame.py:1676] [2/8] function: '_fwd_ttt_inner' (:1564) +[rank0]:W0501 06:43:54.492000 15942 torch/_dynamo/convert_frame.py:1676] [2/8] last reason: 2/7: ttt_model._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._seq_len_cached == 240 # # :256 in forward (HINT: torch.compile considers integer attributes of the nn.Module to be static. If you are observing recompilation, you might want to make this integer dynamic using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this integer into a tensor.) +[rank0]:W0501 06:43:54.492000 15942 torch/_dynamo/convert_frame.py:1676] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank0]:W0501 06:43:54.492000 15942 torch/_dynamo/convert_frame.py:1676] [2/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html +[rank2]:W0501 06:43:54.940000 15944 torch/_dynamo/convert_frame.py:1676] [2/8] torch._dynamo hit config.recompile_limit (8) +[rank2]:W0501 06:43:54.940000 15944 torch/_dynamo/convert_frame.py:1676] [2/8] function: '_fwd_ttt_inner' (:1564) +[rank2]:W0501 06:43:54.940000 15944 torch/_dynamo/convert_frame.py:1676] [2/8] last reason: 2/7: ttt_model._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._seq_len_cached == 240 # # :256 in forward (HINT: torch.compile considers integer attributes of the nn.Module to be static. If you are observing recompilation, you might want to make this integer dynamic using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this integer into a tensor.) 
+[rank2]:W0501 06:43:54.940000 15944 torch/_dynamo/convert_frame.py:1676] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank2]:W0501 06:43:54.940000 15944 torch/_dynamo/convert_frame.py:1676] [2/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html +[rank4]:W0501 06:43:54.961000 15946 torch/_dynamo/convert_frame.py:1676] [2/8] torch._dynamo hit config.recompile_limit (8) +[rank4]:W0501 06:43:54.961000 15946 torch/_dynamo/convert_frame.py:1676] [2/8] function: '_fwd_ttt_inner' (:1564) +[rank4]:W0501 06:43:54.961000 15946 torch/_dynamo/convert_frame.py:1676] [2/8] last reason: 2/7: ttt_model._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._seq_len_cached == 240 # # :256 in forward (HINT: torch.compile considers integer attributes of the nn.Module to be static. If you are observing recompilation, you might want to make this integer dynamic using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this integer into a tensor.) +[rank4]:W0501 06:43:54.961000 15946 torch/_dynamo/convert_frame.py:1676] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank4]:W0501 06:43:54.961000 15946 torch/_dynamo/convert_frame.py:1676] [2/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html +[rank1]:W0501 06:43:55.068000 15943 torch/_dynamo/convert_frame.py:1676] [2/8] torch._dynamo hit config.recompile_limit (8) +[rank1]:W0501 06:43:55.068000 15943 torch/_dynamo/convert_frame.py:1676] [2/8] function: '_fwd_ttt_inner' (:1564) +[rank1]:W0501 06:43:55.068000 15943 torch/_dynamo/convert_frame.py:1676] [2/8] last reason: 2/7: ttt_model._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._seq_len_cached == 240 # # :256 in forward (HINT: torch.compile considers integer attributes of the nn.Module to be static. 
If you are observing recompilation, you might want to make this integer dynamic using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this integer into a tensor.) +[rank1]:W0501 06:43:55.068000 15943 torch/_dynamo/convert_frame.py:1676] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank1]:W0501 06:43:55.068000 15943 torch/_dynamo/convert_frame.py:1676] [2/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html +[rank6]:W0501 06:43:55.606000 15948 torch/_dynamo/convert_frame.py:1676] [2/8] torch._dynamo hit config.recompile_limit (8) +[rank6]:W0501 06:43:55.606000 15948 torch/_dynamo/convert_frame.py:1676] [2/8] function: '_fwd_ttt_inner' (:1564) +[rank6]:W0501 06:43:55.606000 15948 torch/_dynamo/convert_frame.py:1676] [2/8] last reason: 2/7: ttt_model._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._seq_len_cached == 240 # # :256 in forward (HINT: torch.compile considers integer attributes of the nn.Module to be static. If you are observing recompilation, you might want to make this integer dynamic using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this integer into a tensor.) +[rank6]:W0501 06:43:55.606000 15948 torch/_dynamo/convert_frame.py:1676] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". 
+[rank6]:W0501 06:43:55.606000 15948 torch/_dynamo/convert_frame.py:1676] [2/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html +[rank5]:W0501 06:43:55.814000 15947 torch/_dynamo/convert_frame.py:1676] [2/8] torch._dynamo hit config.recompile_limit (8) +[rank5]:W0501 06:43:55.814000 15947 torch/_dynamo/convert_frame.py:1676] [2/8] function: '_fwd_ttt_inner' (:1564) +[rank5]:W0501 06:43:55.814000 15947 torch/_dynamo/convert_frame.py:1676] [2/8] last reason: 2/7: ttt_model._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._seq_len_cached == 240 # # :256 in forward (HINT: torch.compile considers integer attributes of the nn.Module to be static. If you are observing recompilation, you might want to make this integer dynamic using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this integer into a tensor.) +[rank5]:W0501 06:43:55.814000 15947 torch/_dynamo/convert_frame.py:1676] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". 
+[rank5]:W0501 06:43:55.814000 15947 torch/_dynamo/convert_frame.py:1676] [2/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html +W0501 06:44:00.074000 15874 torch/distributed/elastic/agent/server/api.py:739] Received 15 death signal, shutting down workers +W0501 06:44:00.078000 15874 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 15942 closing signal SIGTERM +W0501 06:44:00.079000 15874 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 15943 closing signal SIGTERM +W0501 06:44:00.082000 15874 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 15944 closing signal SIGTERM +W0501 06:44:00.083000 15874 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 15945 closing signal SIGTERM +W0501 06:44:00.086000 15874 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 15946 closing signal SIGTERM +W0501 06:44:00.090000 15874 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 15947 closing signal SIGTERM +W0501 06:44:00.093000 15874 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 15948 closing signal SIGTERM +W0501 06:44:00.097000 15874 torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 15949 closing signal SIGTERM +Traceback (most recent call last): + File "", line 198, in _run_module_as_main + File "", line 88, in _run_code + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 995, in + main() + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper + return f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 991, in main + run(args) + File 
"/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/run.py", line 982, in run + elastic_launch( + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 170, in __call__ + return launch_agent(self._config, self._entrypoint, list(args)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 308, in launch_agent + result = agent.run() + ^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py", line 134, in wrapper + result = f(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 731, in run + result = self._invoke_run(role) + ^^^^^^^^^^^^^^^^^^^^^^ + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 908, in _invoke_run + time.sleep(monitor_interval) + File "/workspace/uv-envs/parameter-golf/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 86, in _terminate_process_handler + raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) +torch.distributed.elastic.multiprocessing.api.SignalException: Process 15874 got signal: 15 diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs_summary.md b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs_summary.md new file mode 100644 index 0000000000..e50499299f --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs_summary.md @@ -0,0 +1,60 @@ +# Log Summary + +The staged seed logs do not contain a completed `quantized_ttt_phased` result. 
+All three runs were interrupted during TTT compile/eval, so the last completed +validation metric is `quantized_sliding_window val_bpb`. + +## Final Metrics + +| Seed | Train stop step | Last scheduled val step | Final completed val_bpb | Artifact size | +| ---- | --------------: | ----------------------: | ----------------------: | ---------------: | +| 42 | 8597 | 8597 | 1.08934733 | 15,999,684 bytes | +| 314 | 8631 | 8631 | 1.09035192 | 15,997,730 bytes | +| 999 | 8620 | 8620 | 1.09285937 | 15,998,747 bytes | + +## Mean + +- `quantized_sliding_window val_bpb` mean: `1.09085287` + +## Source Logs + +- `logs/seed_42.log` +- `logs/seed_314.log` +- `logs/seed_999.log` + +--- + +# Feature Uniqueness Analysis + +Analyzed all 2,015 PRs (open, closed, merged) on openai/parameter-golf as of 2026-05-01. + +## Unique (I think no one else tried it..) + +1. **Multi-Trajectory SWA** — Each GPU rank follows independent trajectory during warmdown (grad sync off), then SWA averages combined across ranks. +2. **Scale Tuning Post-GPTQ** — Freeze quantized int weights, fine-tune only per-row scales via CE loss backprop (Adam, 20 steps). +3. **Two-Pass GPTQ** — Run GPTQ, dequantize, re-collect Hessians on quantized model, run GPTQ again. + +## Partially Unique (maybe our variant is novel? But concept was explored) + +4. **Selective 2:4 Sparsity (training-time)** — Mid-training one-shot 2:4 pruning on MLP weights. PR #1537 tried post-training 2:4 (negative). PR #1818 tried as compression codec (catastrophic). + +## Not Unique (tried by others) + +5. **Mixture of Softmax** — PRs #266, #584, #908, #1227, #1608, #1995. All neutral-to-harmful. +6. **Hourglass Downsampling** — PRs #133, #831, #1275, #1573, #2004. PR #831 called it "catastrophic." +7. **Loop Gate** — PRs #155, #1208, #1691, #1996. Well-explored by 4-5 teams. +8. **Gated MLP / SwiGLU** — 20+ PRs, 2 merged. Most widely tried feature in competition. +9. 
**Knowledge Distillation** — PRs #578, #687, #896, #1029, #1034, #1083, #1185, #1697. All negative. +10. **Hard Token Mining / Focal Loss** — PRs #687, #877, #1233, #1325, #1360, #1380, #1402, #1510, #1702. All negative. +11. **Byte-Weighted CE** — PRs #108, #1033, #1359, #1519. None merged. +12. **Momentum Cooldown** — PRs #534, #1337. Neither merged. + +## Shared (in other merged submissions) + +13. **Fused Softcapped CE (Triton)** — PR #1787 +14. **Batch Size Schedule** — Ternary PR #1184 +15. **Auxiliary CE / Deep Supervision** — Ternary PR #1184 +16. **Phased LoRA TTT + Global SGD** — PRs #1530, #1610, #1626 +17. **LQER Asymmetric Quantization** — PR #1851 +18. **Value Residual Mixing** — msisovic (2026-03-31), SOTA (2026-04-09) +19. **Warmup State Reset** — msisovic (2026-03-31) diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/requirements.txt b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/requirements.txt new file mode 100644 index 0000000000..11b4bdceee --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/requirements.txt @@ -0,0 +1,3 @@ +brotli +numpy +sentencepiece diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/run_final_submission.sh b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/run_final_submission.sh new file mode 100755 index 0000000000..dd19ed5cd0 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/run_final_submission.sh @@ -0,0 +1,144 @@ +#!/bin/bash +# Usage: +# bash run_final_submission.sh # auto-detect GPUs, 3 seeds +# bash run_final_submission.sh --nproc 4 # force 4 GPUs +# bash run_final_submission.sh --seeds 42 # single seed test +# bash run_final_submission.sh --seeds 42 314 # two seeds +# +# Ctrl-C stops cleanly. Completed seed logs are preserved. 
+ +DATA_DIR="${DATA_DIR:-./data/}" +NPROC="" +SEEDS=(42 314 999) + +while [[ $# -gt 0 ]]; do + case "$1" in + --nproc) NPROC="$2"; shift 2 ;; + --data-dir) DATA_DIR="$2"; shift 2 ;; + --seeds) shift; SEEDS=(); while [[ $# -gt 0 && "$1" != --* ]]; do SEEDS+=("$1"); shift; done ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac +done + +if [ -z "$NPROC" ]; then + echo "ERROR: --nproc is required" + echo "Usage: $0 --nproc N [--seeds 42 314 999] [--data-dir DIR]" + exit 1 +fi + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../../" && pwd)" +cd "$REPO_ROOT" + +echo "GPUs: $NPROC | Seeds: ${SEEDS[*]} | Data: $DATA_DIR" + +# Check data +if [ ! -d "$DATA_DIR/datasets/fineweb10B_sp8192" ]; then + echo "Downloading SP8192 dataset..." + python3 data/cached_challenge_fineweb.py --variant sp8192 --train-shards 128 +fi + +mkdir -p "$SCRIPT_DIR/logs" +COMPLETED_SEEDS=() + +cleanup() { + echo "" + echo "Interrupted. Completed seeds: ${COMPLETED_SEEDS[*]:-none}" + [ ${#COMPLETED_SEEDS[@]} -gt 0 ] && parse_results + exit 130 +} +trap cleanup INT TERM + +parse_results() { + python3 << PYEOF +import json, re, os +seeds = [${SEEDS[*]}] +script_dir = "$SCRIPT_DIR" +results = {} +sliding_vals, ttt_vals = [], [] +for seed in seeds: + lp = os.path.join(script_dir, "logs", f"seed_{seed}.log") + if not os.path.exists(lp): continue + with open(lp) as f: text = f.read() + r = {"seed": seed} + for pat, key in [ + (r"stopping_early.*step:\s*(\d+)/", "training_steps"), + (r"swa:applying SWA weights \((\d+) checkpoints\)", "swa_checkpoints"), + ]: + m = re.search(pat, text) + if m: r[key] = int(m.group(1)) + for pat, key in [ + (r"pre-quantization post-ema val_loss:([\d.]+) val_bpb:([\d.]+)", "pre_quant_val_bpb"), + (r"quantized_sliding_window val_loss:([\d.]+) val_bpb:([\d.]+)", "sliding_val_bpb"), + (r"quantized_ttt_phased val_loss:([\d.]+) val_bpb:([\d.]+)", "ttt_val_bpb"), + ]: + m = re.search(pat, text) + if m: r[key] = float(m.group(2)) + m = re.search(r"Total 
submission size quantized\+\w+:\s*(\d+)\s*bytes", text) + if m: r["artifact_bytes"] = int(m.group(1)) + m = re.search(r"stopping_early.*train_time:\s*(\d+)ms", text) + if m: r["train_time_s"] = int(m.group(1)) / 1000 + m = re.search(r"TOTAL_EVAL_TIME:\s*([\d.]+)s", text) + if m: r["eval_time_s"] = float(m.group(1)) + if "sliding_val_bpb" in r: sliding_vals.append(r["sliding_val_bpb"]) + if "ttt_val_bpb" in r: ttt_vals.append(r["ttt_val_bpb"]) + results[str(seed)] = r + +print(f"\n{'Seed':>6} | {'Sliding':>10} | {'TTT':>10} | {'Size':>12} | {'Steps':>6} | {'Train':>7} | {'Eval':>7}") +print("-" * 80) +for seed in seeds: + r = results.get(str(seed), {}) + ts = f"{r['train_time_s']:.0f}s" if 'train_time_s' in r else 'N/A' + es = f"{r['eval_time_s']:.0f}s" if 'eval_time_s' in r else 'N/A' + print(f"{seed:>6} | {r.get('sliding_val_bpb', 'N/A'):>10} | {r.get('ttt_val_bpb', 'N/A'):>10} | {r.get('artifact_bytes', 'N/A'):>12} | {r.get('training_steps', 'N/A'):>6} | {ts:>7} | {es:>7}") + +for label, vals in [("Sliding", sliding_vals), ("TTT", ttt_vals)]: + if vals: + m = sum(vals)/len(vals) + s = (sum((x-m)**2 for x in vals)/len(vals))**.5 + print(f"\n{label} mean: {m:.6f} (std: {s:.6f})") + +final_vals = ttt_vals if ttt_vals else sliding_vals +final_key = "ttt_val_bpb" if ttt_vals else "sliding_val_bpb" +sub_path = os.path.join(script_dir, "submission.json") +with open(sub_path) as f: sub = json.load(f) +if final_vals: + sub["val_bpb"] = round(sum(final_vals)/len(final_vals), 5) + sub["val_bpb_std"] = round((sum((x-sub["val_bpb"])**2 for x in final_vals)/len(final_vals))**.5, 6) +for seed in seeds: + r = results.get(str(seed), {}) + sub["seed_results"][str(seed)] = {k: r.get(k) for k in ["sliding_val_bpb","ttt_val_bpb","artifact_bytes","training_steps","pre_quant_val_bpb"]} +with open(sub_path, "w") as f: json.dump(sub, f, indent=2) +print(f"\nUpdated {sub_path}") +if final_vals: print(f"val_bpb: {sub['val_bpb']}") +PYEOF +} + +for SEED in "${SEEDS[@]}"; do + 
LOG="$SCRIPT_DIR/logs/seed_${SEED}.log" + echo "" + echo "=== SEED=$SEED starting at $(date) ===" + + timeout 1200 bash -c "SEED=$SEED DATA_DIR=\"$DATA_DIR\" \ + python -m torch.distributed.run --standalone --nproc_per_node=$NPROC \ + \"$SCRIPT_DIR/train_gpt.py\" 2>&1" | tee "$LOG" + RC=${PIPESTATUS[0]} + + if [ $RC -eq 124 ]; then + echo "KILLED: seed $SEED exceeded 20 min total wall clock" + elif [ $RC -eq 0 ]; then + COMPLETED_SEEDS+=("$SEED") + else + echo "WARNING: seed $SEED failed (exit $RC)" + fi + + grep -E "stopping_early|TOTAL_EVAL_TIME|quantized_sliding|quantized_ttt_phased|Total submission" "$LOG" 2>/dev/null | tail -5 + echo "" +done + +echo "=== ALL DONE ===" +parse_results +echo "" +echo "git add records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/" +echo "git commit -m 'Final submission: SP8192 DepthRecur PolarNS LoRATTT'" +echo "git push" diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/runpod_8xh100_parameter_golf_setup.md b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/runpod_8xh100_parameter_golf_setup.md new file mode 100644 index 0000000000..534ec2ebd9 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/runpod_8xh100_parameter_golf_setup.md @@ -0,0 +1,161 @@ +# RunPod 8×H100 setup for `parameter-golf-fork` + +This is a clean setup/runbook for a **fresh RunPod 8×H100 machine**: + +- do **not** rely on the small root filesystem for datasets/logs +- keep the **repo + dataset/cache/logs on `/dev/shm`** for speed and to avoid quota/root-disk issues +- keep the **uv environment on `/workspace`**, not `/dev/shm`, so PyTorch shared libraries load correctly +- disable Hugging Face Xet downloads with `HF_HUB_DISABLE_XET=1` + +--- + +## 1) Check the machine and storage + +``` +nvidia-smi && df -h +``` + +## 2) Clone the repo into `/dev/shm` + +```bash +cd /dev/shm +git clone 
https://github.com/PiyushDatta/parameter-golf-fork.git +cd /dev/shm/parameter-golf-fork +``` + +--- + +## 3) Create a uv environment on `/workspace` + +``` +export UV_LINK_MODE=copy +export UV_CACHE_DIR=/dev/shm/uv-cache +mkdir -p /dev/shm/uv-cache +uv venv /workspace/uv-envs/parameter-golf +source /workspace/uv-envs/parameter-golf/bin/activate +cd /dev/shm/parameter-golf-fork +uv sync --active +uv sync --active --reinstall-package torch +``` + +## 4) Verify PyTorch if needed + +``` +/workspace/uv-envs/parameter-golf/bin/python -c "import torch; print(torch.__file__); print(torch.__version__); print(torch.version.cuda); print(torch.cuda.is_available()); print(torch.cuda.device_count()); import torch.distributed.run; print('ok')" +``` + +--- + +## 5) Configure Hugging Face cache + temp dirs in `/dev/shm` + +``` +export HF_HUB_DISABLE_XET=1 +export HF_HOME=/dev/shm/hf-cache +export HUGGINGFACE_HUB_CACHE=/dev/shm/hf-cache/hub +export HF_DATASETS_CACHE=/dev/shm/hf-cache/datasets +export TMPDIR=/dev/shm +mkdir -p /dev/shm/hf-cache/hub +mkdir -p /dev/shm/hf-cache/datasets +mkdir -p /dev/shm/pg-logs +``` + +--- + +## 6) Download the dataset + +``` +cd /dev/shm/parameter-golf-fork +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192 --train-shards 128 +ls -lah data/datasets/fineweb10B_sp8192 | head +ls -lah data/tokenizers | grep 8192 +``` + +You should see files like: + +- `data/datasets/fineweb10B_sp8192/fineweb_train_000000.bin` +- `data/tokenizers/fineweb_8192_bpe.model` + +--- + +## 7) Launch training + +Use the **venv Python** explicitly. 

```
cd /dev/shm/parameter-golf-fork

mkdir -p /workspace/pg-tmp
mkdir -p /workspace/torchinductor-cache
mkdir -p /workspace/triton-cache

export HF_HUB_DISABLE_XET=1
export HF_HOME=/dev/shm/hf-cache
export HUGGINGFACE_HUB_CACHE=/dev/shm/hf-cache/hub
export HF_DATASETS_CACHE=/dev/shm/hf-cache/datasets
export TMPDIR=/workspace/pg-tmp
export TORCHINDUCTOR_CACHE_DIR=/workspace/torchinductor-cache
export TRITON_CACHE_DIR=/workspace/triton-cache

rm -rf /dev/shm/torchinductor_root /dev/shm/triton

bash records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/run_final_submission.sh --nproc 8
```

The wrapper already saves the full console stream for each seed under:

```text
records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/seed_<SEED>.log
```

The underlying training script also writes its own per-run log under `./logs/` and prints the exact path near the top as:

```text
logfile: logs/<run_id>.txt
```

---

## 8) View logs

`git add -f records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/logs/`

`git diff --staged`

## 9) (Optional) Save logs/results somewhere persistent

Because `/dev/shm` is temporary, copy logs back out after the run.

This wrapper normally runs a seed sweep, so treat the per-seed logs as the primary artifacts.

The safest pattern is:

1. Copy every wrapper per-seed console log from `records/.../logs/seed_<SEED>.log`
2. Optionally copy the underlying per-run script logs from `./logs/<run_id>.txt`
3. Copy each seed log into the record folder as `train_seed<SEED>.log`
4. 
Pick one canonical seed log and copy it to `train.log` + +```bash +mkdir -p /workspace/pg-results + +RECORD_DIR=/dev/shm/parameter-golf-fork/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT +SEED_LOG_DIR="$RECORD_DIR/logs" +RUN_LOG_DIR=/dev/shm/parameter-golf-fork/logs +SEEDS=(42 314 999) +CANONICAL_SEED=42 + +# Copy the wrapper's per-seed logs to persistent storage. +cp "$SEED_LOG_DIR"/seed_*.log /workspace/pg-results/ 2>/dev/null || true + +# The training script also prints "logfile: logs/.txt". Copy those too if needed. +cp "$RUN_LOG_DIR"/*.txt /workspace/pg-results/ 2>/dev/null || true + +# Update the record directory seed logs. +for SEED in "${SEEDS[@]}"; do + cp "$SEED_LOG_DIR/seed_${SEED}.log" "$RECORD_DIR/train_seed${SEED}.log" +done + +# Pick one canonical seed log for train.log. +cp "$SEED_LOG_DIR/seed_${CANONICAL_SEED}.log" "$RECORD_DIR/train.log" +``` + +If the run writes any result files or artifacts, copy those too. diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/submission.json b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/submission.json new file mode 100644 index 0000000000..095530bcae --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/submission.json @@ -0,0 +1,33 @@ +{ + "author": "Piyush Datta", + "github_id": "PiyushDatta", + "name": "SP8192 Depth Recurrence + Polar Express NS + Phased LoRA TTT", + "date": "2026-04-30", + "track": "10min_16mb", + "val_bpb": 1.09085, + "val_bpb_std": 0.001441, + "seeds": [42, 314, 999], + "seed_results": { + "42": {"val_bpb": 1.08935, "artifact_bytes": 15999684}, + "314": {"val_bpb": 1.09035, "artifact_bytes": 15997730}, + "999": {"val_bpb": 1.09286, "artifact_bytes": 15998747} + }, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.7", + "technique_summary": "SP8192 + 11L MLP4x + Depth Recurrence (L3-5, NUM_LOOPS=1) + Parallel Residuals + 
Polar Express NS + MuonEq-R + SDClip GPTQ (int6+int8 embed) + Brotli + SWA (SWA_EVERY=1, warmdown=0.72) + Half-Batch (393K) + Phased LoRA TTT (rank 96, warm-start A)", + "compliance": { + "train_under_600s": true, + "artifact_under_16mb": true, + "eval_under_600s": null, + "score_first_ttt": true, + "three_seeds": true + }, + "attribution": { + "sp8192_gptq_sdclip": "@clarkkev (PR #1394)", + "depth_recurrence": "@dexhunter (PR #1331, #1437)", + "parallel_residuals": "@Robby955 (PR #1412), @msisovic (PR #1204)", + "legal_ttt_framework": "@abaybektursun (PR #549), @dexhunter (PR #1413)", + "polar_express_ns": "custom implementation (arxiv 2505.16932)", + "phased_lora_ttt": "@dexhunter (PR #1626), @romeerp (PR #1610)" + } +} diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/train_gpt.py b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/train_gpt.py new file mode 100644 index 0000000000..17874a02ae --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/train_gpt.py @@ -0,0 +1 @@ +import 
lzma,base64;exec(lzma.decompress(base64.b85decode("{Wp48S^xk9=GL@E0stWa761SMbT8$j;gok~)Km^%>Ct#^1&UXf0&BM;k&P^YAa9-nlBV-pJDJex~m%vxD%Uy-?u8)v2#ybHiu>i&j(mlk99Z|v7$H7`S=l1iBJ!x4lZt$2f4W6TG?wsa}ADZf-8`+#EkRMZx0UL`Nl-2ZF6jajd@sd_A`dcEQ?g5NDGkmGSb$V$b0;4xR(+$WD{_E`4jnBS)4n?%hFf_?xYI7Ryk=xHTrZfsU`^JY5X&GiuB75L7Z`4-zl`i^ytkeA=FwrP%cmN&KoRC@?(k-q;4jQWUZ|`)txSnP`H((oGbkI>25!0$rTwqCr8Kz-1TBCi%l6oZrn*q+oN*Dn!XBbNDZaNbmf&o_6Ftu#I|ark>R_z=xc;ZZ{lECF|rxnX#$eT^2sp?-})0bIW4HH*NmCOl{z`ll^?iwt|!>!W7d%$FojQ8cVztSNg{$8+f)MF~KBaQ}~KKXW6-g+IWsT?SGYHb;Tx*a1xMx0XlqRBiGI{l8-cF?33((yx@KgEbaPk0hU*(X*JK>isqD_7cnt>%?vfcTAGTgbP1KQ(1y09r#UY$z=T9wxI*3GTzQ*s*UxDG0-j`=W&XUs4|n(X&Y~I-IXm<}u#EmYsZ=<)!2Q6DxQ`NHNe@!6`3LsMiyCkM@a_FsT|r(AUT7~4WhWF==B7goO;Qj%*dzbMf)Xr2NInDO%H5?nM*(!#O*RA;mQO$OYNur$PeF>+e3pm?7MC;ZjXH~(4&{HXL3b049ZtD(l=&w)7uqr+ss^mM`DCy;tZFA;tsxl^Ofv}i?sKlS8rg(bxysF^e8b^Nd3p^T_o*oCt00LTSTDE+=;r5B(%r)rWB1>a5_p4)a~cx|J0Z`6Z3=*Tp7m*taF@)4?N$~Hpc&X+Bzw{ony4g?{Waf=k`~Iznm9v*@(nh(*0yznh6Q(nki6MzOhI<`H3=kr>!O>;1*C4F%y^d-WXxicJ#4fa``E`|)n|QMCQxUlyj|zR=pNu5xGKoXHhD8B(J~W-9*DR4(;%fYodK^sRDf@Z96~8lz^Bymj5(q%6NI!A>pXRR4Z#8~Veo`0*SzyfVNK5vb$yA(p}j5FKOIYFnc%w?!Xu;Rz^+qDW6wv?fH4uxFT$J_Af#Zo=vxmesy5R)?I!0I(M%+ea-#bPsN+b)OZRS93s#b;sAnZDF^B$nleYkq>lUB;KJ3C9r1>taW`q9QMXdI1W)20#X*R8431pYWh6{LH<3JZ<1)ZKUMJG?!tkcn7@Fq5F%0oxAA~p+0O^uslV*Ti+SkMQcSA(2ffgDP}9HJlaRzVz`xbt-==9Uhnz#2BZeLHF58OOLqN6DvyEMO_V?RWR*>jq)r%cc#r3m--?gpq=-{*8>4(2YWHT*K3*^a=Bhy&gR|r9lB#F>ogtKPi4I02`Uz3&Es;%G{$%Q{cZhM=BtPcwN(a`9t>m<^|$34Hyj853nc1oVzP&caMVM2-5AZCsL}Q_XFcufj@VJQ?(}8L)E?-svwm5xQ^sjez{fw@%q=)c3O)nZyS5wpR4_jq#ekl_{d_KhrJ7={UksuV&Lw+J8gRK>dyWqBBnMz_5r*4cZCI&a7Zzc;LB@e&l>lNfp3rIVXJ2hPfY+I0-9>O1QZErR&L*Hi!N{$?^Sw~#eCv|FiAw=>q~*Xs>vlCDtHIdG#_d}!;cd^&u!zOu$%Z5X}rBA5Vr&ll{&?)Q5L`W954en5%e=jzsdBy)NXzh)|qfO0^AM}Y9-HK>aJa__YDOCjM5~ldol}}vZy->Y7i+Y|^(0};S+C)4j4UkHm&NMzGm=pDD3Jq#9*%qctwTAn1%tISt@ugdiwhV~Mma_$@4YPjeX|9h)^;{9F6K|w1H2+O=E=haXH>Rh;EVCazBCcir7akc=FhNsSQv0nLyk(v@vGG}HTM^bMsy)23yow!htPexb;tNt(4xMiFx-MW_G~C%(cr&N0fbvaB?en=Ur04DbE#xgI(m{%
sVMjMR$M3^7P{iW{h({_&DF;G-~gymhl*p>j8({jB>0C0wLL8ZUv;@{K0byIh0!;VXEr+=G!8`D1TX+WI>1XBZkK&!`g5!W#mZXUW*zst;fcV^-?K>D{JEyoGG=r{zs6+|2_ohOk#5Bi_ySbkLZL`NyBsu3)Y1z!p>Zz@Ls5Iefk4P#NP`1o4-1iSt`5qJN}5YsDfyLx=APL(9Y@`t1*C|ARTruIC#;-EsihA%sd)05fBX653HAQg79A8AE$Q@d(2!r>c-oiir(w8TDd_hE<2uRnVMD5nIzMlc2HTk&Mdae^-k85l`acgl}<^Le&ib@Zwpao&HI3X&_SfI40HmN_d4gZDg#o)mKuPiYa+m+a=@HEMe?nEt|y`yUU@(wwvzS`a(g-fvIONu-Zw&+KHBOV?Iv@VT9QZ@=U7)?=KUEpedph)1c@Q%WVk+SkL3nqWyECX`Wt?0aWrYVt2j^4?cx=QaH84TOYvmzo_B{1+F*4#c)LI{{Qay};xh_9LhY{Sj^S=`@WuDe@tmaNryK9CJ-7u7e9gFEzby4mfuu5O5v;}%%2z(#@zNJaymMI@7CO_W;uyP42O{=N2{Hb*ss$KBmiNPusq{ujj3k-FUNahfYJ#mG8^z4DdpFP7ykkuC_sE@qJfMrgeu`25L{2Hpt7b{+h*K_lKUJ`yF*xU{g6KcK$2JroNj?b+D~fr{%{tQQ-ej5NFf*??_f0X2H7lwY$ESsje6JdXG*f2m0T6zjBoIFD5G$h0t%djH4I20p0HIwC^&@bipI}YG-ZmX>1p_!e-Kdvu9jjvR7qdY7*z*iSx=+Karm~t+qlYXwTfDMLXe0nI2ixXYs2!0@D$xi>3sSlu%Im4KZV1_@zQ4bQEnzYj#(d|DpUDFyCnbtGc746V8>3`9pnUe&m%CsNhH}j8xCPvlt!)&M){e=E?%4DWU$DLoe0p-*`Ga*Tyc+8&p!&BRN01_?Hc2EU@B?hBin%>stxnueQs1fTvkdfS=Xy%g2DzfqNnC#}0oD77Lf&N}UV*zQsoUJ87mmnNqWH4CSpS{~3mOM+qE6vS1-_mMc5UI0k(RK>0VsJX(ZHSanJl4|)R`U~yDtt}*9hs~MPVqqzvL@BGc=_}6`;sVzt3M$Ng*3G9DlZ!TkP%ups)GE6*4}&iXv7M$lS-SZ(6=ZUXH#r-z1vsgenpqPE)#IsVy-&OF4%Gg;0Se6SZ-$&V8Y00EOzf3RB-envlgVeLBv!SA(R`2J2P!Ko)m!#_3s0x0@#M;F3NAg7s8TB11pI&Ka3t$^vwLS})fgzb$m%tj8$`3*=rJG@{%IJ@sMASSnqiM&eHHV^!(R@a3$~V<@tqN@tK66A5wAsC)X4>oerM(9`=|7R+)5wgbwvXl9?>3BZPA+d5;2ncn>k90Y6J+Od$hQ%yT3cHMa|2=&(dlW8+_xDD#kW$HumBA*PreuXCOI|hRnLa&EPrCJ-fVHaD50vxQLg{IjJzhBIyq?Pj(9wIa#OK$MD)iThqP=>q2=zB5y1MGgP{$j3Np5zD7SU0AihUs5KM4c~&i4U6P_*DR%e6mCk9Rup)gC6G3ZE(PO(;a~;hGQ)wT$WgWZbE|x|@CW2`46G2>9QNF%2IQ@3#8iw^$FWeV3aO%doxS&EcX%OF~P`L!cUibujmXK($(D+rH;K;*oi{BOL;yb~9xKWC|DZ=eJcuLtPDj@EY-HnF&->^u59JJ9slUZy+jMy-+Qt63Nu6hkL&N8gw3o45^)oq+EV23|f5j{;^u&TEh&+Hv1^*E#U!FZrTK?Y-KS@Xa)80g?f&|xPTJUtO;ye`~n%)kFldrx3#$>)i2a#5UVA?zGwnpzAMX<3Yj?Y01a+F8`{950Oq6=+BBJBnAivY-rjwlkJ~8Ttu4RB8XL3mP`(%A>#P2rb8fvTP3I{&tKdH&LbC9B~~LT1cq-@$Oj7EwaZQ&TN<>?D%_st(*qCXqm3GT1bmR2}OuiS5QLS=8*+J@~vsgcIOyu#Rkhoxk03;5;E}|H_wTBPcyC6b6|z
ThpYaeNUS^>XS2Xuq^x`$9oIW2BrAH__iO~EoY#Y-5^9~)Rzu~f^$r22fm@t_GS7B3JVgDJ+bZNzIj&2EKe}lmdJjtQD+C!;e${TJ8D4jeR$4iU90&HkKqqEK_OU?c%E~~Nd{B#9Ks12xc4o{-EcSN60=mHgU(3xc}Qd>?JJsemoXEcdI|u*PjM0JJr&Uz3s{yev&Wi-@*g`*xAVlXaXS+jxs}EN*wzoZRk@&z`sYA)XwFbyrf_uXUK&eT3DqeM~BdavNM>Tw+1M{jEdSRLUr0;1Z3=D|TWF1Kt(Mv@$nhSc81!p1PK$0G@g=Pq!rk2{3nVvzz&_s?E%|w%r~_gx7o{Xt^ARU;WPGi&cpE;8~c&y4lmbtls75Q>~xNaf28(^=EI#1cIBG`B?z#b1I35MAHgT30ibnsmVIQJ_~8tKuVq|1sbl|LmGn>oMNz-52=YGw^=`s5S-{OojMD=VZLU5I9lLbXp+5kTPfC7zkuU?3AaJR-c_1>+mdX3%Hi&}RvfPz^gHz?&USI&u`2R|)XFFI7A9(W#wZLbOiZwi_#?wHF4c_bU9qeoXi>gIK6SPb*hNiH52%R=ilgxN_?RxynQP{Vr)FFo9yMQx&`q_MKki4~y-GE3sCfyF?w&$%YS*9CFat*p(caYixZX#*$Zt?RQSz;I;OA}iscpR5y)D3I!zV4K)A{kn3w1PnR`Qthvuc*e8|N*@a#D=eK)u}2f%ukfo0NAZF>ylggGHaW>+~8=hS-uoL3=LTZbQL42N459RWK1AAvbKmBke&B4NUpPn;bk3b$&jKvsykzU9#C0;@E3}GbB&ZOe>|9OX+Muc_iqjl_0k{$)QT{IaE#&gnT0kHYhdIE0F=>YtWP7*ewK(j5f!+c$$8F%@pLTdfzs#b@+ATGuLgOJ|DjgcH^@y80=e3qQlR_)gXK<|$M1gKGav^F9`)4y(ZBkv=Xsq(X5G5elEf#93RvN!owwZ<)Ky-woh%dW(wTm-@$bWXVCpf@?oLCo~?qKdRdm6u+2h09KmZL`K1JOPrRg<&(g&6&DeKI6M{n$`3^q4h+SPLxU&Md^LHvOQO~<4VR$iOkO?8ZlTxkI@+rBz9cI!4hSfy;xv4Dbz;dL$~R9v>nIQgSai!2=i(Q`oQHo?`&78b`@rh~FK8x94`9u7W6lX~|1@-=0XBN_w?oYKYR)I$5eyLidL3+)O*JVRW_a?s=JnucRk{D6(qScLgGqlei1&b__(@wzpB&L!1@GuY0}hk*x#51Q10%?kt^n+AZWY-W|Uv`&yOjyJfm&2?iHH#~@Juyw@)166~ZoAWY{(Elqv0JrH_V3~HZNk9!g?KxcGNvHR@&gk`~S)u;~}T#+tUW-^c%_x3j-K^!xDB5@}YU51m;CnYUXa&6jFZF_#>;&+7~z(M*vxC-8U^|Y-_qG0#bA#l$r2n2UywzOfMpw<0JdeUQYK?7T%A)N1d;rKVOJozkeK(2746`72a_FxEARSSaQtQS^7Hf@y98$~es)Gt!UIi1jHT^`=_4rPQu^JT!GYI7;R*~m8x5=(H)_OLMbm{xnMC7#@;5d5)q`)QgZe6&@CkQ!f8JHf;i!n6QZDGg<7NowpiiGBacb|4Ha3YZtY$(RN{A?M4qP?3@4trM(=z_>ph+pnfkF7$>7)3&Y_ms~-ZEtN?QOcPckpewYNm@kbcRw9SAwxH2?)2Kc9PR*dz_@jeothE%^|wgf)YDrG%>_ysBB?mAX>izNXV#$C|s4qOqMl`$u==6p>h4Z!$#XExZhAfqqI*%T13)I4mk8*W;1{)Xq(%mNK2s@r03E1+I0e5AC7c0|4=-b5xDyX>Wm1P|wA_LLOR3!|{oK{%@N=tDZ8s}MGO;Pw>b==us~u=cvxlh(^Km%*GD8jC7w5$vBHof*qKU34E-mrj;7lp?;`8qvF6u#D7oaGhsQ>K%B9rUITn+L%jm1y;5C_GT_o6cIy7PEfq$APmorZS$uRm@$P8ed{KBAK-|bN38#;8=9~T8GwC!UDdPH4)uCu~oA6
2BG{mw{S-)J5MUY>Y^GL7e=FZ`V#o;p2;3J+tMyPRVsi62pirmf0zW@uG9$g-POmc1CZWe~jjRr{$6+6&F=Nxa1A`wdY@GU&}bCjnb;-QWBXCQCwplQ6~o_ykYkd8T1QL4nL`=mDSdQM~`RbCG@6wnvBXd9nSTC00a%>YEG(MzXPb09E^u+fu>glPTt*o>(0cej?Ho*FA0xdeZu&lzWfY=}nb7miBpG@O>_H)~r$9bOnF$#ezkz}Li}Q)j|q5b1t^qnemRb~73n9pA?#mzn1Ccg-k*KJ;pc=LYZ}a9uBupXNk;(yoZJ-7a4P`#h|Srz1Oeil*dUDQheUlCB6x3nl~}L79&+J|Ba+P4gA;7U$r(W}C+^1U5ksp6j!|voLFgPj7{{;PAn5BJnV1T2cgA$bM#0OhDZa=bn7kC0e~z@Qj;hHE(Za^AS|J8^T*WZl#v)PH#su@-Jq0yx*T;u0YeU*c#hIh+qqek9Hwp^+WE{VCnM3nbkPX6LY@>9$=x$dVmw4u)RL5@>;mcib2UV+e+9A8u5kRP(d_bBC=+!mWiCAtN#*@Krv2Y0JSU!xsvz~goCV`rcam9;Lhj?VgAN!PQD^cK$oz^eFC#tK*D0dF#ZqfVJ3R~$Mj}Y#ygm*DoQk{$pA?sHdw*i*4aA)$$TXnX>I04HMDgF|iAy>MMrF0#))rASFXSP|U`YU!<)hrs>1q5}1K%<^e4L<_TM3{2anO=bi1KjggLX6fEA-g{;>N}tl>F&=G_NJW%K$R-qS|>uFB^)<-VDbQ^&f8k+Ka(MRT}U3~EhNpczB44^FWF|~`gzH*&4HL&D8rxpu4)h$6q2KE;YV&5SH0LWv1SD~#5y~QrDWN>r;$9cvGP^9Uo{wSYeClEe`*k-xLAXre>RgPH{u`Q_oE!h(3lbXk&p{%mD$VH?;{(JGZqFSML+lE`7z?56>&8!y7Wk@SC+gkIc-Gbbx>ag^Wj0s12(tD_slC1_6)&Qdk?N%9ipo?T(DVk>;Hc@L)h7;Sfj@Rv_*n5i2M@nNYS+JW+=k^&?m`BpHit-;ZuQ_FB#r)KhO;~Z(9ywDGAEv|Ju6`>EdM;H>3N4ajegwH$%b%p%Xq!cgok2?Qg`ehA;Z4YCWbU0Aai@)<<6mMDjp*xfsC4@d$R4S(lGq#hS_7o`&|oFSbwATSczW16+2PYO&jX1d6|K?J)T%~)JOF?1Zuq%_WVA>;BvQO=?@35vEHEHQ9)6fpNUTQBEO!WFJ37af%GYo3M>E5)8*Dm5UqPr4?g2ca#WRf>~gme0gfknjmSBvwy}yQ2SP{w%$=8GOb?0qXi`tJWwmKGT^g^)?GN$48F4{)3x$)b1!v1iOHbWNBoC_?$CJ>NdKJzfMF_X3g8xG?~bSYmzYTVDaLVS`?H%^;@#dk$9VLMjt(gNGFp(tEWLk(hvQn3r{s#(*TeFi{FfarNlzS>3f`&i3O02|HW|`g`O&ICHpoau-^t?@_?bzjMLjQ6QOtd!_PaU5Mm#;XqK=RP2qe2b@a@Ih!S0be+5&oWuE_4DJN;FTa0uPLyg?#E|c7CZZ3!KdVX>=qcN0Hu+-!Ru{m%N%X;1Vlg@*%i5SFIIpE#W`#29wcDL57hAlhOhWd-wEP<<96-JDj>Ji9*`S!G|!X3Ml`Hr8H<=+xIExvlol5dkEp^!EsbSBbD)bx4%)z%LnPaAP1xx-hjKXYTfTp$(cCzw-0#`_160fX}orPBd^;UuvdcRBne-vi9{{j@EOf7{os>C~iA+8grv{u*}3A2+Bgy@`#Xeh#{Wm-{J-e@a^`r*87}ytX|j@7=0G|ZmVM2c!*%l4sM{n`wIoygIno}eVBwdlBJinF2#V9##dI_D)XtS^Qy8wrx9Osix4G(*yB3T>I$8E<^>3bQfLK;fl;Q~8Tebi5@l@}2dZNLp%X5Oi#k11TPPJW+5ZHnTVy`GQiDhd=4)*8B9lAynK+;+LZQhgY;;!TRib~ssDK5yuR}P0+?q_V*!x_*8@6w$LNOHg7kuR1=;M`7YYj=AF
MGJ;*}7@l~fb#wwAxYZm@;iOg^2nZtXPyRdmMA?z-+4vs=)ooVb|>LrosPD&RK@E$XTH+8vvM9rZy}N42ekuvJ@d?am;*e1fJhqK3Y2+p+ArSI@Ko}@R*vTt@Cd13m|HnL>X|qX)&xiLntdsEAD;;?oiI*S2n5|;nb9)Hu9U&j$&Ap-_)qtYxsOR|Kzp;Y$ivN0t1-c^&i^BUm5dPt|Ra}#TK%tS-*xF6h`?6!z{?o6rVp3FZd)B1gH-&_Yik3fu+sOO!#v%d(!XpSmDeWZiIO>I$b}6xIf1OVPN}fEeAr};y#5k(AJm65CT0E#nI@d7cSqKyuKAzhPCxfxh0G7V1NL0=-&9Wak|-->JxE+w-7q9`6CXso9wZ%+SY|@#u(z=hJD(damfu{L{L@c8))F5B<}kC-8X&$;SfUfBiPuHKxel&_<@xm~~Z*1Q>sV{!E42JYG;W{*!(prOvJF2K9ZlR!rz4w5J<4;Rw9u77YLn<5ll1xDGTlI7{M7dMt0|wr5I+XyNv**7px*Xy;uxK+DXW|arr=)aMA>kHolNx7t09hs3&O=!JwvqVb{IVumGhLGb;iGX*edXO656stl(`nLgmhz*Z-**&pWX1t|N{L{S>C26&CuqnnDEYjIsRiSTy0jn)^Ybej@GoF3kN6cBq0m_M28bC{_3IV;Ns3N|^~KA=+7sENN|Vnjde*4eh%aKoj1#p#rF+>ZnD;GJ*Wju^-{f2Dn@*P(t-wD4=vdIF_-hu4vbHYWVjhGpM7CH=IPpPG!Li&pHkZ;7aS5SrHAX&AYwQagfHA=VnoOc%>z4|W>lj(KJ@q|b9L;V&jtiuD@-te69fI>uww~kzX5M@;N%mkt0Nzhf__{-YO^KVP^e{sDZ^&U@zlM!~ivmYMA&mPXdl|6I#kY05CvkIhC|6wb*TY|H&!|F{kAjS#!J=yHdstNgd^W{Tnx!6LrwlYM$7I^9syu+TbJp+E!k#~f?tcKZL|NMyeT@axpFo&WG8R5&xnhx{b0Q9YCzj6l17@6PD4dRr9sAqRIWag>@hOw&T>F%a8{tz{O;iCc40;XmY!g!lH+Sq}RVpN;bgXq-i}bD1R{`KzAacI}jcS-Ba)&=Iy#f5+tkq+To|;bTIv(q}|Np&mWP_UV2br^qls~5TR+){ymfN?vzlD@#9Ar91moSu3E0R*Ubr@WyLvh)_fO_ofn~~!m(9gVPjSzaMa}Gtnxt!t)%Ml@rbArWZj!^AtM@e0;AeJ;j-ojqkqrK>)vx}VgWH~v6#2H$YF*C8NcmD?YIw|x_S${TXvcMw?|Dh2RL~aJIOmita2*YIMgD5*)mbD?{^QRz&7bCg@f|a?+c!*=!Z9Oi)+}lVeN?jR~eX1^E$k=0XH72&Re^!id&yR>WnmI!MsLAHjOFj=%VeVrwQl~<2#GK!!bc-{mleyiYT623xJy_lduE<=ysE3dhVCSgw-sNfomzapuS|i16y6e;9DAEU!+jk^Gg=x;FmNpm+?_%!R<^1nj;;6dpLW#Gm(&4skUX8?eu=0a+C2Op1qjxLdWkIS%?{lqMr8B!AC5@i%8JhPR{jkHeDVj*qJaiNzv15$l;l1d1eCF3gvIi_N7bEL?<`j&@ZKZf(>Yff_ZFbOLso&CPgc^))yr_>~U!DbxD#NZF3?Ekc!1HHxDjcDavQTvLF55cJVpeA1<{?-@8poCPS4)Y7QlbZ$xb@;qKKPM3RuKt9UK=eUwWstHNe%{l!-@?<+joSIkknwA(Rf=0QG)5{+?$N0>29oXmNZ#3_AL_=?Mh;b(Z|Kx8h$BzIsx3cyG|JYVe&7!3W`dW2pSuG5LUWh|FKJIkz*zU#WOw$-zuJ01a626Vsch`K4p@hG)G1$|;%Em9SmG%B$t5oolKH$KJ~czHmf+i*0aoe59tSfTnOp_Poo@AXDc7StQ2Fnp_`jQ{2aBAd`MtY}B0+$VQQv{j^tA3^gP5g?BfQ?01dcMXs)(k%C3P1cr#F>(H}w!z2R6GH8@z^)r6d>lCOSlM=(IU?8Hz
Bj0YwC6t{6rDSk6~!!t;-B{wW!y?>B;wM$Lap?1=#ld+O8OpZz%&0pEX3;Y=`>cS1J5r7S+!FEqZv0S+3k;O69tH3e%k45@IFOsrme$ktm}>CP+C;E$Kkcw^4elj>JbNmYx)0zee45;LFVN!wiS#Eu#3!~0SF_LbMiOa#!#8X9gV(3;4IcRYyU>CEbK0U+qwP03?%9!@e4r%I}Cbn^9WviP1M9Vcsd|61}QC?l86fI?~wCqym;0E6l1G|dnt$-$;#ebo-Jg+@cZbq#C86|&^-IQ8;K7}tc(x|s?tR9hfLx~`6Lvh(1H3%hFq|IRv@QH~7*1<{?PUGk;&9E_&|dGse7sH@;H(>19BUGDH!CK=@aOfwmMaj4+^u(DNQ;45fB3fxim>5N@04Y4tp#O!phHeZnPjlNaf$WB=6Of=fmbL$c0ciIV@N^_>}Zw)#G-t6&P|fWF(O+YXg3W2q$uCk!5~4HD`TCD+`uceL;G`3j?ZkQNom0^=kLdq&V$S-@GcFz|F6RDZ&v7ov>W%LVN8D)^ahwRzIo?d;)+-_(@%xU=XHq07R^j6H2-4(of|ZGMxNgro+~n?Lz?>obTci}bc7Qcnokjo=^S9Q9qQ!~Wv1bU`-M9o`?IvMPIjZ7+%B&<8AZrS}EC15M*k}`-SGSJNb9XTFU-saMeOhcn6kRe4-b}sTOqrYdT3AeL!jm(6Lxb`Hs%m|nxoZuXaQ0~HadZzU_3NF#UvwP)*`@j8Qap^Kie_s0<5&&t|V_YE~j)2d@3a8We8qQu5`X)dm$nm9CQ&O}TLbC_=o^THg&k%$HMlJNkeTa2XHx?|5<*aW%W2;~>LL|F;vSmpCRB@2VYTLPZVfBi@FlCUOMaQgou2l63P2JoUYQYHg3bwH`sNrTE7m05?)^bdiq)A4#vjIKw>^scKW4K{ihCz34k^b6KE3oL}0~`@)4Sw^I|@J^|m8`d~-S0$FDhKncot;m4uBOFFL_b#n*;UEdeT)c3+M1%oO*4eUa3`hcR;G=G5CPGdc2a`FX+R($jPfTc0eJNxQa<}{051c=#@2QI$f^`IwJ~CmA7c2nOFs=`F}w1ir~Q`pJjB#!l?fcc3;OC=Ga^iEH?hqrzuU+s>9sY~f3BYcR94>fsQ1fwh=e#PMeQn21%@o~h{kK+AvkX`R^11nY%}@)kw1=4CUl@E<(|)zH0SIY(NR&)sX`%|IS2jS$%II$9(JMLp`^^IRfNp56r+;Me&Jhc+EaUfl4LY4JP2c76^1hClnF%vrF1qp&o-C_03nUl%qc47?*)`>-7pUdrt{lyMnn+_5;|0GIGY3!ye{o*JdCRIPCnIN+BQVWTt&AE`t>4DrDSi7pV{qKc+emc&10EB)_O*`&a!pVx;|l{P23#&k6slo_5B15${e=4oJD4G=&fm511Hu8!roz^NM{+y{Yy*y>?QdWuf{(qLWmb3;AR^k%E}|7CKFTM9oAc{Zh>v?uYLy=Cg}g~B$hiKZLQBNVcujJ2XRP`hrCUwWvWu!P@sVNhm?1wWVN!?@((a|qK#ylTzIW?R`Xl;FN*B$ej;o9!W>cZs?F0xc>cFqOKMAni>yj?oI-F|g0-@tJ|k3E{#1U5<}FGmrSDpG$+N;6pMPb)pAOCQ&Fp906*{xuy{g7qNwg2U_88SBx~5fg9gWrCyU00p%DYlW>p>Gn`0;FyPGmktb4|eg5QAj+`@5v%nO?_0|YTaZ_W}A0=mEL%1d}WcvZhO(9G}+gCTVwNs-C-@B8oa>a)y5D9xDVWIX2*-S(EDUhB8SjyHZ_zL^tmr#DaQQUHXG=^cQ~MsF3&)~^Jwst8=&rAHY^NM~^RYn02SNdVe96}M*a!A}>_w5ZpP6lkwx6uRFzUyCU|#68e|vU_amP`w2#s}}VW#>_~cp_lY+38XkmEa`r*D28s>KOg2m{d2!|Ji}lH<;ZC~06y0=%y;{MAQnYz^Dagzs2fB0(I9D{s|nn0y@_tGRsLi{Cr~)+ONCgZj4TJub=6dy)w*q9qEKipOAgL8u`q`
7gfd-_dr2N@zG%$%c9>gfGjtZGiLy1dUi3xW|x6E&q-)e{rEJq=g%2A3+hagvfFv)=w!5Rk+F}xK9jx2DXD~n@$}=<@bh5Eh4E{Oil$Daii%=rcgY6AOPf<;g)=ts*&!sa>6CZR+Pha`frk@s)D;Hp)4sMFRN@y_12iZoV#sZtpzbYNu4y3yE^9NYlGf??g`2}asMn~+XXs{2W<+#B(>{ojIw)F#MqVf#{p7x?`m9CH@?Ny2o{@cyIhFI*<0`~`19Oa;@ua9>wGR1_=k{+r5*F-422v(%m4}cJmg78I&(AfzvlLsiVrANM;frk2A{7P7$}&WSj85CIK8Z8Ui*g|If&u)U2aGN5XtLyuM`s3?G3rzu9=n$Zgbh;?x8+DHG344;QwqujIKwh6NNb4p;lYLT?B@Tqz3A$M;nZ?q}S1;LK>OD8`)?T_Ms1|JgU2igd~O_$gn_dRe>1&jS|N^c88JlodK%+u=uq!RJezNgX%(}6|6zTm}W5mO(KfvqhC2WEOz=lGYo=BFcqu1-8;CGeAL4rH5e3d;(&Y0p)oTGjuV?3r1L$V}1&Kd~MXDc$J$nL?4=5%7K)f*p5Y(B@~sF9EZYrCtyG(8N{a4ZJ;?#C%L}4%01;g;G&c`G+&+!*!Cq_B_%E9-*9VMbHz{(vp2LrN=!bDYZG1X#D7Vpznt3kb%}L7>LF5_37~G|+lISp}8uDHEBI3-2`J2{#f?&TXCgGMPGK-4P;WI@6|N!cmYUj97`lNP0V>qyEWUak(~|<;27KZG8GIY`utI(TLS3tU#Wk>sJxdYvg2Prv7saf~UTog=#bsCkrvG^`={X{`U=KqG=3hl0Eli`F&&yMo;{!8Tj&fHzDPbMpIja3)JF~3g)q^*OJ*x<2|4o6v7+AvzaZG@BXjl^;pcEM^#VS+#x{w+2si`HRmI`jePgx+YbTmWi&cC#ARYhnhqcEri8@L{0gmz`KkO6uFb5A40!EP`Ao(jPJX-+91uj6rqX_XD5RL27{<)-KI5OlMCEe=Vo-<;o@MTvs55K9Ngen-eE{7-9ye2=_dmxsOvgkxbfSkL+P0qS9>@^7g$zIgxBE_HTod(Y;JR+o83b0&KhO{~pe}?X>*oK_6r9cHsCUUotp&-UWW&E5JZX*NFJ7Z~B!AO9!>eko&6{p5f>yh^IPS%c`gQFA+gs4byoR6m)7k0T%G#(>Mr)QE3)_xpc|D77ohg|tpwHGan1rdbmtGT-CBdXcH06Ia+YQ;7t6)192ezql!}%Xf?=uBm^)J+(i$f1U3-thBnpqD@@J(E5%v$SBAkyrY=l<5Z5bLYk0y*h=>@){8Mv@|W`H{Zv>$Q0|{dbe9cZ8wv0%|`zRIMxVA~^Xdg)6zps+xeDq>j>wt3*GyLE+)(Z2jFYT}cd!C?;L!?>Y3!MBf<25_EjGi805E=7yY70Gyt*w%0-E(pcr(q{}1klefe=Uz6oyte`;FcOh{&AtcJf5`6_`a%))$0)5*G~6ZD!P^Z4GSj2S$x}Ss4=>v2f_Pj2)^4d=O^q}mw-flC_@`>zziTHq&DZHqu(0fnbx>Bm^2kg3X?yAtjpgSfoR70*84$ru7EufkQ)#5WiuEVACMsnydKW8mle$vQ+@$h;T77&~ioqI=s#48WFt2?ak%sYt|Py-%8knwEDXl6S;mSjL^SDmRq(UXc6q7U?qP<#?R+mB|j9?P9;E=k_sm1g7nv4+I)l@dmQ%W(Td%0lI&FX@p=ja?hqq4OM1VH|YiG8?cS8@fI%aUR5TVf1yX7eM@@Jcm_;0SGVc6qX%5cWh|>tk26*x_Wz+0+o&4`Im$|+H$lbKyxRY7}sRm6Bj7_DVXx!bMFRxb1UTs{=kOoOPw@yjZWZRPOq2{OW`{92X49A5(k4)Lxr+HCMvJ9l;DL|w&LS&rV-@x&t*+Z7{~Vm0kW5wAm3QTHFe1U!yHuK;cM{s03Fiurl8I9+Kk2yyRuVhyt==XNFileT2ri>>fH+fQ2X#Ltk7_s1`5x`2{DCL4I$itr_0`QKNz)
TYehZDY(+VIKtHhB^Kz2#9)F)9gs5g&P>&J+k(eZoDYBPAl4>6A?mJNb7d5{}*W&JPU2ceT`ws&rS14W!ZG}8}SlTY#ke9Z>@@L>Rn2|q}td3EcP{`>~gVEE!-cCqb^{r{QiEDCo_)3!p^$z8X&bPka8k`MPzCP88Jsj3LW}kBK>iz1eD2cWbpx3w)&X@Lonzbld7bo=5l_0>b&G7w_u_)(nXZn;7U`D$&EFs4j!+$#6BrD^xG^hLG8+LvrZJPtMGZI4cvQ;tzz0?t7MG#0OQW)6a}bZ5z&G=xGh}}b4Y2Qn4-I$?x+_1VZdfc_WO>i)6BNeq1i-7pB!o;;y<165S|YpZ0*Yfke4`c=6!aM6}ZjqKJB%VCcu35lzCG*B&}|VcjVUzCpQ_foAjbo}7>W-x}=AbOQE!?GUYvIMex?l-iD9z+A6%uTIE2&w07+kFB8-?rls5=;dI%>|C2scZ^B7Y4MngdtfXb9BWinVb&!uhyx-(tvd>Gq6<4JgU?EiJ@Hq0HD5awBZvmQqQJowG|wgLft@Sv0(_e7`U)4<=&nKi+gy(tUNQdO{6b+onnI+X5L7jJC;IWpq~-J0`T4RC@GZO8bp|tr;=+<0hTfgW!z2-Xq7Dw?R#i(wEb_j=MZq0-(p*EtRN*my46FhwuU2AHdhI35YHp9ItF**QR|tDMoMO&?E@N!xpmyV&U{S*d3(~8P#pvcXn*@CV-QOsy+Zk9)073Re|wW8XcJ9>jv_LFo(_J+lXWT+M&NK9Xmt{cVdC6Lk;ST#J_N!&!74)714h(kf^f|ZHn4HeA4IZ!@|QSaCkpXBhk-8a-~7c=Kcp6ZdrJW@YS(d7KM+wbDuwt49s0Qes*H(E1wD!mu>Wox&JXP-5l7;Y2qoll|^SwJqH`pbS41Jpeq!aOLzoxRuJS;!QeFCoK)VtIhV_5*`iEb4#JR*+WjmEjO#_8`2?`t35?DS{c1@D!u-BE(rm$~!{$z*o;b_PL6vAM;Z2UyIxw6yfsEhx;wxta8Hy~^t7&qO`Gq4f-KK=dJY{io`eiF}^lreeT_OT@|ff_oRjB`mk-gMA6mS!DTpl|OaCp_U^hXRs`KkNY9d?Q|Sb85e1Gh9BE@vBHM7qCQ|R9Eol%?#HC|0YIG=5f%*p-=6Fc?h4QcpZBzx1`aGpd$MHbJagcoDf|Kx^wah^#I=X)e^!b%U2Cr_+~xlEuhu=Ee^2id5-Fun*?DHQhFW+v$3;nnu3ty3l&&5=N=j#`sx$4^=(4bSU=C6*EkN}(?M51ug>E2ib~6O{Dlmt%v2BKgF}hx!>B{$VJxjBeCRQi=`Oqg+S&VX+C{}=B#(e$E@C6zg~w0TI0REmyRhYj>)!V0j`;CJ+dle2R>B@*k^T%8p!4pJK-Z*MKrY|_*-S^h~9n!zO+T{QDNN7`^lm6FY!E{o3`yLA}S2m_h*q-TOcE~U)Cr;20DsnM(VV(<;MbWU`|**6C?FGnK0%Cf%(!2>Q>MvCM&XDa3XuZ9f!~RLlUR~o{w1)sKM+{rn6&wq!E_hrL+KTQ0xR|8K3aT2OHqrzQ#D_B*-sf=Vy8(2xSwBQ7HmwYra1XkBQH=}1u8>fX>yF5Y>)I9~%ZD&Y2l*LXeFdTjXf+^+T8iv}d;ArVYvZZv$vYK4+c|Au`DW7S1ji)%E9vY@SD&7if0iHIh|7A}+sevnT(C?nFCa0bqb9xzc*aQ!JQndaM8-+m+I*pUTisF~ZBmIQ0e1VG(f)o#JDYy1nWv6uRO0?^I;rBT6ZcEqY=*3yp^UkZol2)z`VueriM%2ONd`ffM#@)2gH*#kh3(2CugXp*|FR=PBj}_Y#%DtzOHd*`I$8Uz5ctGsYWIgwNO_9+F0#eui(mZ(bhA`zC_vMc2@TF))6w?3du#(Y!BQVvK|R4+XTjmUGF(==2t)l?(oLk5c(j;?061jx3C}-uAodhljiwxC!tzbH?{9jLRVsrd%t)2R%;CXYLGi3c*Gc_}Wl!NoTW&0+jUs8ro@pE)-Nz6kkMW2iqwc{BTD
I@4sNT(IENHHw*fET>TfbtWg!n1KE}ljc6EN=cSTEPkTgw2~>~ycMN^8xA%AglZi{i9%8d)#ywD^PT;}(vtV(26aZ-k_s%=Xh9C+fHMKh;Lefp!reeS0|9zjjknMsJiJ)!+66+kLR#af<%Wo$w+lNBOM!^U!CCb1%=oe@0Ns{Q8$E2mFrE`Xu=|jgfA@>p37wU8nom@D2VmP@hzQ5RZ*>~=M0=F>Ca%#e4*4h>o!0kPiOkzCoDM1$JC~PhCce2Q4ROG%!8v_}LV>FXOVooWd|dfEP$Hs=;P6W))d|Zr}#cWG!DPo?$c~y3Sh{S&AHS*BHIBU=nHni_%Wc}^oG3^>KZaZ)&F|j1=Mebn{GJln9-%Y%;y?2ahRh<;EV(i^I=-z;H$;>-8+=+CGGuV5*a-w(t7`An`dNC*1D$7*DS(p3?l*Kw>>6?Nw5zt~v9T{>ZW|QV$9}TqM@S=<`_rKZ{AM3Z~Z;*oq2;k>RSfjLuhv4!Lb8S^zbSZavu4j%;Fk)wTclLr)-epm={7hMUb8e%ENWCU~F*oq-7ZiiVyIy#+aw>xYIoK6f)5C^cKmwGcEsu(kwsIS(mAiNE4kapd(q`rwXkCQr~OgLYU%Q`v=m-tX!%U(*roKY~)rxOJWO`kzdQGtA(2K1Kt_ylSJf;%`FEg^`yF6+thnWuuV?X$3ne&1bLL_q%MExL_25c@v^mLJyGxvmC5g};>gE!SmL`C{Y^r^I)FFg`%oni0wmnt*{eJO&U5jXdaC@a$Kr;2zptt*Ik322IOPA|3oyBvE6!O07-Fg|3NQp}xvY!Q_VDGneGLpSwmA3#c5HyIi-Je;u*hnb)_))ckesK(id=(los&lTVsqY!K55oLE?fS2$azY{A({|2wrt#;r|GO*eMUbK{%$BYRaX7SqP4gb91Kpy(v%X97w@i<2|)Zq*=+Q0|>i=zqi{_jIf8bh*R^HqNt7XZVQ?u!Rx0?T6sKKYtK(*R>VwA#zj*?x@qgRe{^-W~>A2cz8j0|M}q6+z>cyXRR6a9cYNkr`Ws?updkv3W;>Wr>J0niI1To8L-Q{L|%MA^JWnvFrsaFw$ZKUyh0Zb-0lcNpL@7l6&3Ky!}YZjHDi=V`+{q(J=!Ga)KD>2&jnOZd!KRoJGa%ltc&$dTtF>5Rq64^{jLQ8I!=SOaRtZ`bSoJ}Dt*C%8*;;i7#BDC)N|z`W1#=Acats`(Y6dVYM8WkKGLPOZJDgN*eRsS*m#9Ls~{>iIFEiR1LjsY$QQ>&R8u@~h)5C;^sV=Ilf7Z>00Q7?yY2MD)q)jsHGj5hQto&i-bzBw*j|7L{6b;sc)t$wLi6$E}f;cX0fD90YJ0umk3Br4w1MZU_Jr%8wm|{74>bJ;zHAu%R$4YU*}GAzdT9?FS3}qcvx~O9d(U6=M_ez=i`~z$BNCLUjL`Z>7ja1#5#swvPm5XwM=PVXJsNG;cuj0S5(R!zpfr@?CZ=It1PDMd8yRlG|D@cR*l7&E9PTN||Lju{TqN8c$LL+*>cOK+I)ex^{WQrxh}pCQ3hrT^Et7`m%CRngiUvSv>e8dSdgqm`qAjN<4bP5HYp23yl0H2o!4KC|#^aet=8b9>YchpTw%jOUMU2MN_Mr?ba_dFe4NM(y*`Rf=pBm9_A4G+z62%7s*xFf&qJ;ZAyzLDKwu{A$g_%XiI*brA!s)7hgP3Y?Z--G(P8m5V=%{!v8a}ewhzNq}$1pI|ODSGFTU>BD;IEFcC~v=Eq>okA+8-tL+CcW+XgC!s0{<6SE%xU20HKs4)oGVAR$H%44mf+kWUW0NUjSV&W!6SIxS9QpG-OvE%$02>Cu`2390&yu~}dYbmw{CcM)5id!gLCZ}3{D(0guNl;ueUvo`JfHc0U=Q`cY-o6b=QVkjQ%SWe+O?mGC{t%Eq`1~c&6aW-vg49q~a)}}J)iLzL-~bhI&7pXS926i9GI~mNmX8EcMorz67JYYB@II)fW%l6R+VuHW8zuc2uS{a-j9imoXFgZ2*w9NXM>O74t+=nkh!|*N
Llce>>%8P>Zm-mzuNriA;iwUiU_gj8Ep-hXBFLLb;7ph*d`v(3SLKx=9?F4-e+m>)4DWo3Gcm`q;#(+V?kUzxE$TkVc<0%qjDfa%mJPfO0qu?9~Sv$;>(>2u+h^0J?&e`ZCX2rQ0c@Hb643SugE5V?oQDZb*lpycWFp_W6L<$6Ygr4mp>=b1=xJ%rKPt0W{#vCKFv24nXJld1fpXdCJ&UvcJ4L)*^~!&xbZ5Wg1t?F&2Bk3D`z%w31uLa4`2sjxupJvMFd`U7a0*Mo+%#|EhXN3Zx94fz8a#_Kl9loJX6ktKg>3!xCDZ;CRof1!*EFm8GNGGZ_CA3Ei{o|6GF?cBqL}RrWM3>G*}%ExHCTIrn|2_jh*Fv^?5=QEOw80#NY~R*!V%zB_q+KKa}MtvsyhhiAMT@q!dDkx3C!GdpAPq*G^|%Oa&UZzatbb`{A@%&h(9>F|$DiF=u&oF|gh}e)6f!f7OdPo~Gj#OlmB#zU^o62*n1T=)$uhl=Lt>{Y@kE1HF?uk$LpyX_Amxz%9VC3o#7yJfE7mu?9Df%F)m3p5J`yF*}b$RC+ZCHAmW(1Fzzihjh!jbyCHM8wnJXfy5vB&Gv7ei>8+oG%?jMqxDqT#_OW`h4V$@OZe0dHHRQCQVIJH{Mxj?txk1*LkaoJ5mDWnm|quiup_j6IkhU&w9)C+8H33vmy+v+)enP|Lu>@mSmvMffN=JJFHjSN)RF=Kf)6ezt=qYJboSko|Qmbb2O#f`ALwNC&Ju!62w&(F2AdoyOLwjqFrYl{yNlL=QQf|RA+0=QVrjzt7JY|Vs&$r?wT4v{(Y@u(Fs6n9MYt|9pd#BLKy<|ITKkCpWz-Gs?ADbX_Wb-?wLphtDe2cDLw~^nFn%GsAVak|N8j7D8CJYq4`Y6~$P#trn?j3srfmhh7}|?3Au%5n;qpupvY(jHA?F9`0y0bn@+1jQWp*Q@IwTN(J_+7sKJwx_>xEI&p8jnLojbaeM?qy+O_D%DdwFK`Av^(CS#?38$1$6YT!kig`5yf_djPkU-#xBK%^|FT%za@Sc86Zd8s{#+v|?wzPT8(SsEQw3SGhjYB9msPhz*;xA2P$t6xj-=cKun2!0LBLvg3a9WQg4vDBkQLH)k^SRP4gyFWkOtgiy1@Zb8(cMd0pJtTKJ(f)s7WC^U_t5Z@_Cgd=BGkv#-^NwNo@bA-dZd~7A;CpgTEF|%t)V$c@ccOW+|7RJ&OU64Q(ppg!Yowm-#K%fVJAgQ%|Ok*nm7Dw_s8Z7>1(FC%kPl^IOH$6Y4qE{-BGf9u?uf}Ja@RGv23e+3#zu_b~FU>Kd!zMFNP~Pzo+1B8^PRNdnjn3>58i~@&nV^6UYha{>Y8}=a;fOSCQCT+o=_FAC-T0}+Ji8_ze^XgpYHmeK5xSiWpbr1|x*w-!}``EiU8GoH(DJwZ1*bk5f%1z7tNj4iK tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + @softcapped_ce_op.register_fake + def 
_(logits: Tensor, targets: Tensor, softcap: float): + n_rows = logits.shape[0] + return (logits.new_empty((n_rows,), dtype=torch.float32), logits.new_empty((n_rows,), dtype=torch.float32)) + @torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) + def softcapped_ce_backward_op(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous(); grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + @softcapped_ce_backward_op.register_fake + def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + return logits.new_empty(logits.shape) + def _softcapped_ce_setup_context(ctx, inputs, output): + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + def _softcapped_ce_backward(ctx, grad_losses, grad_lse): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = getattr(torch.ops, _FUSED_CE_LIBRARY).softcapped_ce_backward(logits, targets, lse, grad_losses, ctx.softcap) + return grad_logits, None, None + softcapped_ce_op.register_autograd(_softcapped_ce_backward, setup_context=_softcapped_ce_setup_context) + def softcapped_cross_entropy(logits, targets, softcap, reduction="mean"): + losses, _lse = getattr(torch.ops, _FUSED_CE_LIBRARY).softcapped_ce(logits, targets, float(softcap)) + if reduction == "none": return losses + if reduction == "sum": return losses.sum() + if reduction == "mean": return 
class Hyperparameters:
    """Run configuration.

    Every field is a class attribute that reads an environment variable with a
    baked-in default, so a run can be reconfigured entirely from the shell.
    Derived fields (paths, distributed topology) are computed from the earlier
    attributes at class-definition time.
    """
    # --- run / schedule ---
    data_dir = os.environ.get('DATA_DIR', './data/')
    seed = int(os.environ.get('SEED', 1337))
    run_id = os.environ.get('RUN_ID', str(uuid.uuid4()))
    iterations = int(os.environ.get('ITERATIONS', 20000))
    warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', .72))
    warmup_steps = int(os.environ.get('WARMUP_STEPS', 20))
    train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 393216))
    train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048))
    train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500))
    max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 6e2))
    val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 524288))
    eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048))
    val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000))
    sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1')))
    # --- model ---
    vocab_size = int(os.environ.get('VOCAB_SIZE', 8192))
    num_layers = int(os.environ.get('NUM_LAYERS', 11))
    xsa_last_n = int(os.environ.get('XSA_LAST_N', 11))
    model_dim = int(os.environ.get('MODEL_DIM', 512))
    embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512))
    num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4))
    num_heads = int(os.environ.get('NUM_HEADS', 8))
    mlp_mult = float(os.environ.get('MLP_MULT', 4.))
    skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1')))
    tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1')))
    logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 3e1))
    rope_base = float(os.environ.get('ROPE_BASE', 1e4))
    rope_dims = int(os.environ.get('ROPE_DIMS', 16))
    rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048))
    ln_scale = bool(int(os.environ.get('LN_SCALE', '1')))
    qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 5.))
    parallel_residual_start = int(os.environ.get('PARALLEL_RESIDUAL_START', 7))
    enable_looping_at = float(os.environ.get('ENABLE_LOOPING_AT', .45))
    # --- optimizer ---
    min_lr = float(os.environ.get('MIN_LR', .10))
    embed_lr = float(os.environ.get('EMBED_LR', .6))
    head_lr = float(os.environ.get('HEAD_LR', .008))
    tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', .03))
    tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', .005))
    matrix_lr = float(os.environ.get('MATRIX_LR', .028))
    scalar_lr = float(os.environ.get('SCALAR_LR', .02))
    muon_momentum = float(os.environ.get('MUON_MOMENTUM', .95))
    muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5))
    muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', .85))
    muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 500))
    muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1')))
    beta1 = float(os.environ.get('BETA1', .9))
    beta2 = float(os.environ.get('BETA2', .95))
    adam_eps = float(os.environ.get('ADAM_EPS', 1e-08))
    grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', .3))
    eval_stride = int(os.environ.get('EVAL_STRIDE', 64))
    adam_wd = float(os.environ.get('ADAM_WD', .02))
    muon_wd = float(os.environ.get('MUON_WD', .095))
    embed_wd = float(os.environ.get('EMBED_WD', .085))
    ema_decay = float(os.environ.get('EMA_DECAY', .9965))
    # --- checkpointing / quantization ---
    compressor = os.environ.get('COMPRESSOR', 'brotli')
    swa_enabled = bool(int(os.environ.get('SWA_ENABLED', '1')))
    swa_start_frac = float(os.environ.get('SWA_START_FRAC', .12))
    swa_every = int(os.environ.get('SWA_EVERY', 1))
    gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64))
    gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 0.))
    matrix_bits = int(os.environ.get('MATRIX_BITS', 6))
    embed_bits = int(os.environ.get('EMBED_BITS', 8))
    matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85))
    embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 2e1))
    hessian_clip_lambda = float(os.environ.get('HESSIAN_CLIP_LAMBDA', .175))
    eval_temperature = float(os.environ.get('EVAL_TEMPERATURE', 1.))
    sparsity_start_frac = float(os.environ.get('SPARSITY_START_FRAC', 0.))
    # --- test-time training (per-document LoRA) ---
    ttt_enabled = bool(int(os.environ.get('TTT_ENABLED', '1')))
    ttt_lr = float(os.environ.get('TTT_LR', .02))
    ttt_epochs = int(os.environ.get('TTT_EPOCHS', 3))
    ttt_momentum = float(os.environ.get('TTT_MOMENTUM', .9))
    ttt_chunk_tokens = int(os.environ.get('TTT_CHUNK_TOKENS', 32768))
    ttt_lora_rank = int(os.environ.get('TTT_LORA_RANK', 96))
    ttt_ns_steps = int(os.environ.get('TTT_NS_STEPS', 0))
    ttt_swa = bool(int(os.environ.get('TTT_SWA', '0')))
    ttt_reset_per_chunk = bool(int(os.environ.get('TTT_RESET_PER_CHUNK', '0')))
    ttt_lora_alpha = float(os.environ.get('TTT_LORA_ALPHA', 144))
    ttt_warm_start_a = bool(int(os.environ.get('TTT_WARM_START_A', '1')))
    ttt_lora_lr = float(os.environ.get('TTT_LORA_LR', 0.0001))
    ttt_chunk_size = int(os.environ.get('TTT_CHUNK_SIZE', 48))
    ttt_batch_size = int(os.environ.get('TTT_BATCH_SIZE', 64))
    ttt_grad_steps = int(os.environ.get('TTT_GRAD_STEPS', 1))
    ttt_weight_decay = float(os.environ.get('TTT_WEIGHT_DECAY', 1.0))
    ttt_beta1 = float(os.environ.get('TTT_BETA1', 0))
    ttt_beta2 = float(os.environ.get('TTT_BETA2', 0.999))
    ttt_k_lora = bool(int(os.environ.get('TTT_K_LORA', '1')))
    ttt_mlp_lora = bool(int(os.environ.get('TTT_MLP_LORA', '1')))
    ttt_o_lora = bool(int(os.environ.get('TTT_O_LORA', '1')))
    ttt_optimizer = os.environ.get('TTT_OPTIMIZER', 'adam')
    val_doc_fraction = float(os.environ.get('VAL_DOC_FRACTION', 1.0))
    phased_ttt_prefix_docs = int(os.environ.get('PHASED_TTT_PREFIX_DOCS', 2000))
    phased_ttt_num_phases = int(os.environ.get('PHASED_TTT_NUM_PHASES', 1))
    global_ttt_lr = float(os.environ.get('GLOBAL_TTT_LR', 0.001))
    global_ttt_momentum = float(os.environ.get('GLOBAL_TTT_MOMENTUM', 0.9))
    global_ttt_epochs = int(os.environ.get('GLOBAL_TTT_EPOCHS', 1))
    global_ttt_chunk_tokens = int(os.environ.get('GLOBAL_TTT_CHUNK_TOKENS', 32768))
    global_ttt_batch_seqs = int(os.environ.get('GLOBAL_TTT_BATCH_SEQS', 32))
    global_ttt_warmup_start_lr = float(os.environ.get('GLOBAL_TTT_WARMUP_START_LR', 0.0))
    global_ttt_warmup_chunks = int(os.environ.get('GLOBAL_TTT_WARMUP_CHUNKS', 0))
    global_ttt_grad_clip = float(os.environ.get('GLOBAL_TTT_GRAD_CLIP', 0))
    ttt_eval_seq_len = int(os.environ.get('TTT_EVAL_SEQ_LEN', 2048))
    ttt_eval_batches = os.environ.get('TTT_EVAL_BATCHES', '')
    # --- loss variants / tuning ---
    mos_k = int(os.environ.get('MOS_K', 1))
    byte_weighted_ce = bool(int(os.environ.get('BYTE_WEIGHTED_CE', '0')))
    batch_schedule_enabled = bool(int(os.environ.get('BATCH_SCHEDULE_ENABLED', '0')))
    scale_tuning_enabled = bool(int(os.environ.get('SCALE_TUNING_ENABLED', '0')))
    scale_tuning_steps = int(os.environ.get('SCALE_TUNING_STEPS', 20))
    scale_tuning_lr = float(os.environ.get('SCALE_TUNING_LR', 0.001))
    scale_tuning_batches = int(os.environ.get('SCALE_TUNING_BATCHES', 8))
    # --- distributed topology (derived from torchrun's env vars) ---
    distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
    rank = int(os.environ.get('RANK', '0'))
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    is_main_process = rank == 0
    # Global batch is sized for 8 devices; fewer devices accumulate more.
    grad_accum_steps = 8 // world_size
    # --- derived paths ---
    datasets_dir = os.path.join(data_dir, 'datasets', f"fineweb10B_sp{vocab_size}")
    train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin')
    val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin')
    tokenizer_path = os.path.join(data_dir, 'tokenizers', f"fineweb_{vocab_size}_bpe.model")
    logfile = f"logs/{run_id}.txt"
    model_path = 'final_model.pt'
    quantized_model_path = 'final_model.int6.ptz'
    # --- knowledge distillation ---
    kd_enabled = bool(int(os.environ.get('KD_ENABLED', '0')))
    kd_alpha = float(os.environ.get('KD_ALPHA', 0.5))
    kd_temperature = float(os.environ.get('KD_TEMPERATURE', 2.0))
    kd_top_k = int(os.environ.get('KD_TOP_K', 32))
    kd_logits_dir = os.environ.get('KD_LOGITS_DIR', '')
    kd_warmup_frac = float(os.environ.get('KD_WARMUP_FRAC', 0.0))
    max_eval_seconds = float(os.environ.get('MAX_EVAL_SECONDS', 6e2))


# Hyperparameters handle consulted by log(); installed via set_logging_hparams().
_logger_hparams = None


def set_logging_hparams(h):
    """Install the Hyperparameters object that log() uses for rank/logfile."""
    global _logger_hparams
    _logger_hparams = h


def log(msg, console=True):
    """Print from the main process only, mirroring the message into the run logfile.

    Before set_logging_hparams() is called, falls back to a plain print so
    early messages are never lost.
    """
    if _logger_hparams is None:
        print(msg)
        return
    if _logger_hparams.is_main_process:
        if console:
            print(msg)
        if _logger_hparams.logfile is not None:
            with open(_logger_hparams.logfile, 'a', encoding='utf-8') as f:
                print(msg, file=f)
def build_sentencepiece_luts(sp, vocab_size, device):
    """Build per-token lookup tables for byte-weighted bpb accounting.

    Args:
        sp: a SentencePieceProcessor-like object (vocab_size, piece_to_id,
            unk_id, is_control/is_unknown/is_unused/is_byte, id_to_piece).
        vocab_size: model vocab size; tables are padded up to max(sp vocab, this).
        device: torch device for the returned tensors.
    Returns:
        Tuple of tensors on `device`:
        - base_bytes (int16): UTF-8 byte length of each piece (1 for byte tokens,
          leading '▁' excluded).
        - has_leading_space (bool): piece begins with the '▁' word marker.
        - is_boundary_token (bool): control/unknown/unused/padding slots.
    """
    sp_vocab_size = int(sp.vocab_size())
    assert sp.piece_to_id('▁') != sp.unk_id(), "bad tokenizer"
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes = np.zeros((table_size,), dtype=np.int16)
    has_leading_space = np.zeros((table_size,), dtype=np.bool_)
    # Padding slots beyond the tokenizer vocab stay marked as boundaries.
    is_boundary = np.ones((table_size,), dtype=np.bool_)
    for tid in range(sp_vocab_size):
        if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
            continue
        is_boundary[tid] = False
        if sp.is_byte(tid):
            base_bytes[tid] = 1
            continue
        piece = sp.id_to_piece(tid)
        if piece.startswith('▁'):
            has_leading_space[tid] = True
            piece = piece[1:]
        base_bytes[tid] = len(piece.encode('utf-8'))
    return (torch.tensor(base_bytes, dtype=torch.int16, device=device),
            torch.tensor(has_leading_space, dtype=torch.bool, device=device),
            torch.tensor(is_boundary, dtype=torch.bool, device=device))


def load_validation_tokens(pattern, seq_len):
    """Concatenate every validation shard matching `pattern` into one tensor.

    The result is truncated to a whole number of seq_len windows plus one extra
    token so each window has a shifted target.
    Raises FileNotFoundError when no shard matches and ValueError when the data
    is shorter than a single window.
    """
    shard_paths = [Path(p) for p in sorted(glob.glob(pattern))]
    if not shard_paths:
        raise FileNotFoundError(f"no files: {pattern}")
    tokens = torch.cat([load_data_shard(path) for path in shard_paths]).contiguous()
    usable = (tokens.numel() - 1) // seq_len * seq_len
    if usable <= 0:
        raise ValueError(f"val too short")
    return tokens[:usable + 1]
self.start_inds],dtype=np.float64);x=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64);y=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64) + teacher_topk_idx=None;teacher_topk_val=None + if self.kd_enabled:teacher_topk_idx=torch.empty((device_batch_size,self.seq_len,self.kd_top_k),dtype=torch.int64);teacher_topk_val=torch.empty((device_batch_size,self.seq_len,self.kd_top_k),dtype=torch.float32) + for bi in range(device_batch_size): + total=remaining.sum() + if total<=0: + for si in range(len(self.files)):self._reset_shard(si) + remaining=np.array([len(s)for s in self.start_inds],dtype=np.float64);total=remaining.sum() + probs=remaining/total;si=int(self.rng.choice(len(self.files),p=probs));start_ind=self.start_inds[si].pop();remaining[si]-=1;mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[start_ind:start_ind+self.seq_len+1],dtype=np.int64));x[bi]=window[:-1];y[bi]=window[1:] + if self.kd_enabled and teacher_topk_idx is not None: + tmaps=_get_teacher_logits_memmaps(self.files[si],self.kd_logits_dir,self.kd_top_k) + if tmaps is not None: + t_idx_mm,t_val_mm=tmaps;t_start=start_ind;t_end=start_ind+self.seq_len + if t_end<=t_idx_mm.shape[0]:teacher_topk_idx[bi]=torch.as_tensor(np.array(t_idx_mm[t_start:t_end],dtype=np.int64));teacher_topk_val[bi]=torch.as_tensor(np.array(t_val_mm[t_start:t_end],dtype=np.float32)) + else:teacher_topk_idx[bi]=0;teacher_topk_val[bi]=0. + else:teacher_topk_idx[bi]=0;teacher_topk_val[bi]=0. 
+ result=(x.to(self.device,non_blocking=True),y.to(self.device,non_blocking=True)) + if self.kd_enabled and teacher_topk_idx is not None:result=result+(teacher_topk_idx.to(self.device,non_blocking=True),teacher_topk_val.to(self.device,non_blocking=True)) + return result +class RMSNorm(nn.Module): + def __init__(self,eps=None):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + def forward(self,x): + w=self.weight.to(x.dtype) + bias=self.bias.to(x.dtype)if self.bias is not None else None;return F.linear(x,w,bias) +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=1./base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=False);self._seq_len_cached=0;self._cos_cached=None;self._sin_cached=None + def forward(self,seq_len,device,dtype): + if self._cos_cached is None or self._sin_cached is None or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=1./new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[None,:,None,:];self._sin_cached=freqs.sin()[None,:,None,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and 
rope_dims1:k=k[:,:,:,None,:].expand(bsz,seqlen,self.num_kv_heads,rep,self.head_dim).reshape(bsz,seqlen,self.num_heads,self.head_dim);v=v[:,:,:,None,:].expand(bsz,seqlen,self.num_kv_heads,rep,self.head_dim).reshape(bsz,seqlen,self.num_heads,self.head_dim) + y=F.scaled_dot_product_attention(q.transpose(1,2),k.transpose(1,2),v.transpose(1,2),is_causal=True).transpose(1,2) + if self.use_xsa:y=self._xsa_efficient(y,v) + y=y.reshape(bsz,seqlen,dim);return self.proj(y),v_out +class MLP(nn.Module): + def __init__(self,dim,mlp_mult):super().__init__();hidden=int(mlp_mult*dim);self.fc=CastedLinear(dim,hidden,bias=False);self.proj=CastedLinear(hidden,dim,bias=False);self.proj._zero_init=True;self.gated=bool(int(os.environ.get('GATED_MLP','0'))) + def forward(self,x): + h=self.fc(x);act=F.leaky_relu(h,negative_slope=.5).square() + if self.gated:act=act*torch.sigmoid(h) + return self.proj(act) +class Block(nn.Module): + def __init__(self,dim,num_heads,num_kv_heads,mlp_mult,rope_base,qk_gain_init,train_seq_len,layer_idx=0,ln_scale=False):super().__init__();self.attn_norm=RMSNorm();self.mlp_norm=RMSNorm();self.attn=CausalSelfAttention(dim,num_heads,num_kv_heads,rope_base,qk_gain_init,train_seq_len);self.mlp=MLP(dim,mlp_mult);self.attn_scale=nn.Parameter(torch.ones(dim,dtype=torch.float32));self.mlp_scale=nn.Parameter(torch.ones(dim,dtype=torch.float32));self.resid_mix=nn.Parameter(torch.stack((torch.ones(dim),torch.zeros(dim))).float());self.ln_scale_factor=1./math.sqrt(layer_idx+1)if ln_scale else 1.;self.parallel=False + def forward(self,x,x0,v_residual=None): + mix=self.resid_mix.to(dtype=x.dtype);x_in=mix[0][None,None,:]*x+mix[1][None,None,:]*x0;attn_out,v_out=self.attn(self.attn_norm(x_in)*self.ln_scale_factor,v_residual=v_residual) + if self.parallel:mlp_out=self.mlp(self.mlp_norm(x_in)*self.ln_scale_factor);x_out=x_in+self.attn_scale.to(dtype=x_in.dtype)[None,None,:]*attn_out+self.mlp_scale.to(dtype=x_in.dtype)[None,None,:]*mlp_out + 
def sparse_kd_loss(student_logits, teacher_topk_idx, teacher_topk_val, temperature, softcap=0.):
    """Sparse KL(teacher || student) over the teacher's top-K tokens only.

    Args:
        student_logits: [N, V] raw student logits.
        teacher_topk_idx: [N, K] vocab indices of the teacher's top-K tokens.
        teacher_topk_val: [N, K] teacher logit values at those indices
            (already softcapped upstream).
        temperature: softening temperature applied to both distributions.
        softcap: if > 0, softcap the student logits first so both sides are on
            the same scale.
    Returns:
        Scalar loss, averaged over the N rows and rescaled by temperature**2
        (the standard Hinton correction keeping gradient magnitude
        temperature-invariant).
    """
    if softcap > 0.:
        student_logits = softcap * torch.tanh(student_logits / softcap)
    # Student: proper log-softmax over the full vocabulary.
    student_logp = F.log_softmax(student_logits.float() / temperature, dim=-1)
    # Teacher: probabilities renormalized over just its top-K support.
    teacher_p = F.softmax(teacher_topk_val.float() / temperature, dim=-1)
    student_logp_at_topk = student_logp.gather(1, teacher_topk_idx)
    # KL restricted to the top-K support. The teacher-entropy term is constant
    # w.r.t. the student but is kept so the value is a true KL divergence.
    teacher_logp = torch.log(teacher_p.clamp(min=1e-8))
    kl_per_row = (teacher_p * (teacher_logp - student_logp_at_topk)).sum(dim=-1)
    return (kl_per_row * (temperature * temperature)).mean()
self.tie_embeddings=h.tie_embeddings;self.tied_embed_init_std=h.tied_embed_init_std;self.logit_softcap=h.logit_softcap;self.fused_ce_enabled=_FUSED_CE_ENABLED;self.tok_emb=nn.Embedding(h.vocab_size,h.embedding_dim) + if h.embedding_dim!=h.model_dim:self.embed_proj=CastedLinear(h.embedding_dim,h.model_dim,bias=False);self.head_proj=CastedLinear(h.model_dim,h.embedding_dim,bias=False) + else:self.embed_proj=None;self.head_proj=None + self.num_encoder_layers=h.num_layers//2;self.num_decoder_layers=h.num_layers-self.num_encoder_layers;self.blocks=nn.ModuleList([Block(h.model_dim,h.num_heads,h.num_kv_heads,h.mlp_mult,h.rope_base,h.qk_gain_init,h.train_seq_len,layer_idx=i,ln_scale=h.ln_scale)for i in range(h.num_layers)]) + if h.rope_dims>0: + head_dim=h.model_dim//h.num_heads + for block in self.blocks:block.attn.rope_dims=h.rope_dims;block.attn.rotary=Rotary(head_dim,base=h.rope_base,train_seq_len=h.train_seq_len,rope_dims=h.rope_dims) + self.final_norm=RMSNorm();self.lm_head=None if h.tie_embeddings else CastedLinear(h.embedding_dim,h.vocab_size,bias=False) + if self.lm_head is not None:self.lm_head._zero_init=True + if h.xsa_last_n>0: + for i in range(max(0,h.num_layers-h.xsa_last_n),h.num_layers):self.blocks[i].attn.use_xsa=True + if h.parallel_residual_start>=0: + for i in range(h.parallel_residual_start,h.num_layers):self.blocks[i].parallel=True + self.hourglass_enabled=bool(int(os.environ.get('HOURGLASS_ENABLED','0')));self.hourglass_down_after=int(os.environ.get('HOURGLASS_DOWN_AFTER',3));self.hourglass_up_before=int(os.environ.get('HOURGLASS_UP_BEFORE',8));self.hourglass_factor=int(os.environ.get('HOURGLASS_FACTOR',2)) + if self.hourglass_enabled:self.hourglass_skip_gate=nn.Parameter(torch.zeros(h.model_dim,dtype=torch.float32)) + self.looping_active=False;num_loops=int(os.environ.get('NUM_LOOPS',1));self.loop_start_idx=int(os.environ.get('LOOP_START',3)) + if num_loops>0: + 
ls=self.loop_start_idx;le=int(os.environ.get('LOOP_END',5));loop_seg=list(range(ls,le+1));all_idx=list(range(ls)) + for _ in range(num_loops+1):all_idx.extend(loop_seg) + all_idx.extend(range(le+1,h.num_layers));ne=len(all_idx)//2;self.encoder_indices=all_idx[:ne];self.decoder_indices=all_idx[ne:] + else:self.encoder_indices=list(range(self.num_encoder_layers));self.decoder_indices=list(range(self.num_encoder_layers,h.num_layers)) + self.num_skip_weights=min(len(self.encoder_indices),len(self.decoder_indices));self.skip_weights=nn.Parameter(torch.ones(self.num_skip_weights,h.model_dim,dtype=torch.float32));self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,h.model_dim,dtype=torch.float32))if h.skip_gates_enabled else None + self.loop_gate_enabled=bool(int(os.environ.get('LOOP_GATE_ENABLED','0'))) + if self.loop_gate_enabled:self.loop_gate=nn.Parameter(torch.full((h.model_dim,),-2.0,dtype=torch.float32)) + else:self.loop_gate=None + self.mos_k=h.mos_k + if self.mos_k>1: + self.mos_scales=nn.ParameterList([nn.Parameter(torch.ones(h.embedding_dim,dtype=torch.float32))for _ in range(self.mos_k)]) + self.mos_gate=nn.Parameter(torch.zeros(h.embedding_dim,self.mos_k,dtype=torch.float32)) + self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=.0,std=self.tied_embed_init_std) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if getattr(module,'_zero_init',False):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=1.) + def _hourglass_downsample(self,x): + """Downsample sequence T -> T/factor via avg_pool1d. x: [B, T, D] -> [B, T//factor, D]""" + f=self.hourglass_factor + return F.avg_pool1d(x.transpose(1,2),kernel_size=f,stride=f).transpose(1,2) + def _hourglass_upsample(self,x,target_len): + """Upsample sequence back to target_len via repeat_interleave. 
x: [B, T_down, D] -> [B, target_len, D]""" + x_up=x.repeat_interleave(self.hourglass_factor,dim=1) + if x_up.size(1)>target_len:x_up=x_up[:,:target_len,:] + return x_up + def _forward_hidden(self,input_ids): + """Forward through all blocks, returns hidden states before final linear projection.""" + x=self.tok_emb(input_ids);x=F.rms_norm(x,(x.size(-1),)) + if self.embed_proj is not None:x=self.embed_proj(x) + x0=x;skips=[];v_residual=None;enc_iter=self.encoder_indices if self.looping_active else range(self.num_encoder_layers);dec_iter=self.decoder_indices if self.looping_active else range(self.num_encoder_layers,self.num_encoder_layers+self.num_decoder_layers);is_first_layer=True;seen_loop_start=False;x_at_loop_start=None + hg=self.hourglass_enabled;full_len=x.size(1);hg_skip=None;in_downsampled=False + for step_idx,i in enumerate(enc_iter): + if self.loop_gate is not None and self.looping_active and i==self.loop_start_idx: + if not seen_loop_start:seen_loop_start=True;x_at_loop_start=x + else:g=torch.sigmoid(self.loop_gate.to(dtype=x.dtype))[None,None,:];x=(1-g)*x+g*x_at_loop_start + x,v_out=self.blocks[i](x,x0,v_residual=v_residual) + if _V_RESIDUAL_ENABLED and is_first_layer:v_residual=v_out;is_first_layer=False + skips.append(x) + if hg and not in_downsampled and step_idx==self.hourglass_down_after: + hg_skip=x;x=self._hourglass_downsample(x);x0=self._hourglass_downsample(x0);in_downsampled=True + if x_at_loop_start is not None:x_at_loop_start=self._hourglass_downsample(x_at_loop_start) + for(skip_idx,i)in enumerate(dec_iter): + if hg and in_downsampled and skip_idx==self.hourglass_up_before: + x=self._hourglass_upsample(x,full_len);x0=self._hourglass_upsample(x0,full_len);in_downsampled=False + if x_at_loop_start is not None:x_at_loop_start=self._hourglass_upsample(x_at_loop_start,full_len) + sg=torch.sigmoid(self.hourglass_skip_gate.to(dtype=x.dtype))[None,None,:];x=(1-sg)*x+sg*hg_skip + if self.loop_gate is not None and self.looping_active and 
i==self.loop_start_idx: + if not seen_loop_start:seen_loop_start=True;x_at_loop_start=x + else:g=torch.sigmoid(self.loop_gate.to(dtype=x.dtype))[None,None,:];x=(1-g)*x+g*x_at_loop_start + if skip_idxx.size(1):scaled_skip=self._hourglass_downsample(scaled_skip) + else:scaled_skip=self._hourglass_upsample(scaled_skip,x.size(1)) + scaled_skip=self.skip_weights[skip_idx].to(dtype=x.dtype)[None,None,:]*scaled_skip + if self.skip_gates is not None:g=torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None,None,:];x=torch.lerp(scaled_skip,x,g) + else:x=x+scaled_skip + x,_=self.blocks[i](x,x0,v_residual=v_residual) + if hg and in_downsampled:x=self._hourglass_upsample(x,full_len) + x=self.final_norm(x) + if self.head_proj is not None:x=self.head_proj(x) + return x + def _forward_pre_softcap(self,input_ids): + x=self._forward_hidden(input_ids) + if self.tie_embeddings:return F.linear(x,self.tok_emb.weight) + else:return self.lm_head(x) + def _mos_loss(self,x,target_ids,reduction='mean'): + """Mixture of Softmax loss. 
x is pre-projection hidden states [*, dim].""" + tok_w=self.tok_emb.weight if self.tie_embeddings else self.lm_head.weight + sc=self.logit_softcap;K=self.mos_k + # mixing weights: [*, K] + pi=F.softmax(x@self.mos_gate.to(dtype=x.dtype),dim=-1) + # compute K softmax distributions and mix + p_mixed=torch.zeros(*x.shape[:-1],tok_w.shape[0],device=x.device,dtype=torch.float32) + for k in range(K): + x_k=x*self.mos_scales[k].to(dtype=x.dtype)[None,None,:]if x.ndim==3 else x*self.mos_scales[k].to(dtype=x.dtype) + logits_k=F.linear(x_k,tok_w) + logits_k=sc*torch.tanh(logits_k/sc) + p_k=F.softmax(logits_k.float(),dim=-1) + p_mixed=p_mixed+pi[...,k:k+1].float()*p_k + # loss = -log(p_mixed[target]) + p_mixed=p_mixed.clamp(min=1e-8) + flat_p=p_mixed.reshape(-1,p_mixed.size(-1)) + flat_t=target_ids.reshape(-1) + losses=-torch.log(flat_p.gather(1,flat_t.unsqueeze(1)).squeeze(1)) + if reduction=='mean':return losses.mean() + return losses + def forward_logits(self,input_ids): + if self.mos_k>1: + x=self._forward_hidden(input_ids);tok_w=self.tok_emb.weight if self.tie_embeddings else self.lm_head.weight;sc=self.logit_softcap + pi=F.softmax(x@self.mos_gate.to(dtype=x.dtype),dim=-1) + p_mixed=torch.zeros(*x.shape[:-1],tok_w.shape[0],device=x.device,dtype=torch.float32) + for k in range(self.mos_k): + x_k=x*self.mos_scales[k].to(dtype=x.dtype)[None,None,:] + logits_k=sc*torch.tanh(F.linear(x_k,tok_w)/sc) + p_mixed=p_mixed+pi[...,k:k+1].float()*F.softmax(logits_k.float(),dim=-1) + return torch.log(p_mixed.clamp(min=1e-8)) + logits_proj=self._forward_pre_softcap(input_ids);return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward(self,input_ids,target_ids,byte_weights=None,teacher_topk_idx=None,teacher_topk_val=None,kd_alpha=0.,kd_temperature=2.): + if self.mos_k>1: + x=self._forward_hidden(input_ids) + if byte_weights is not None: + losses=self._mos_loss(x,target_ids,reduction='none') + bw=byte_weights.reshape(-1).float();return(losses*bw).sum()/bw.sum() + 
return self._mos_loss(x,target_ids,reduction='mean') + # Knowledge distillation path + if teacher_topk_idx is not None and kd_alpha>0.: + logits_proj=self._forward_pre_softcap(input_ids);flat_logits=logits_proj.reshape(-1,logits_proj.size(-1));flat_targets=target_ids.reshape(-1) + if self.fused_ce_enabled:ce_loss=softcapped_cross_entropy(flat_logits,flat_targets,self.logit_softcap,reduction='mean') + else:logits_sc=self.logit_softcap*torch.tanh(flat_logits/self.logit_softcap);ce_loss=F.cross_entropy(logits_sc.float(),flat_targets,reduction='mean') + flat_t_idx=teacher_topk_idx.reshape(-1,teacher_topk_idx.size(-1));flat_t_val=teacher_topk_val.reshape(-1,teacher_topk_val.size(-1)) + kd_loss=sparse_kd_loss(flat_logits,flat_t_idx,flat_t_val,kd_temperature,softcap=self.logit_softcap) + return(1.-kd_alpha)*ce_loss+kd_alpha*kd_loss + if byte_weights is not None: + if self.fused_ce_enabled: + logits_proj=self._forward_pre_softcap(input_ids) + losses=softcapped_cross_entropy(logits_proj.reshape(-1,logits_proj.size(-1)),target_ids.reshape(-1),self.logit_softcap,reduction='none') + else: + logits=self.forward_logits(input_ids);losses=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),target_ids.reshape(-1),reduction='none') + bw=byte_weights.reshape(-1).float();return(losses*bw).sum()/bw.sum() + if self.fused_ce_enabled: + logits_proj=self._forward_pre_softcap(input_ids) + return softcapped_cross_entropy(logits_proj.reshape(-1,logits_proj.size(-1)),target_ids.reshape(-1),self.logit_softcap,reduction='mean') + logits=self.forward_logits(input_ids);return F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),target_ids.reshape(-1),reduction='mean') + def forward_ttt(self, input_ids, target_ids, lora): + """Forward pass with batched LoRA adapters for TTT. 
Returns per-token loss [bsz, seq_len].""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips = [] + v_residual = None + enc_iter = self.encoder_indices if self.looping_active else range(self.num_encoder_layers) + dec_iter = self.decoder_indices if self.looping_active else range(self.num_encoder_layers, self.num_encoder_layers + self.num_decoder_layers) + slot = 0 + is_first_layer = True + seen_loop_start = False + x_at_loop_start = None + hg = self.hourglass_enabled + full_len = x.size(1) + hg_skip = None + in_downsampled = False + for step_idx, i in enumerate(enc_iter): + if self.loop_gate is not None and self.looping_active and i == self.loop_start_idx: + if not seen_loop_start: + seen_loop_start = True + x_at_loop_start = x + else: + g = torch.sigmoid(self.loop_gate.to(dtype=x.dtype))[None, None, :] + x = (1 - g) * x + g * x_at_loop_start + x, v_out = self._block_with_lora(self.blocks[i], x, x0, lora, slot, v_residual=v_residual) + if _V_RESIDUAL_ENABLED and is_first_layer: + v_residual = v_out + is_first_layer = False + slot += 1 + skips.append(x) + if hg and not in_downsampled and step_idx == self.hourglass_down_after: + hg_skip = x + x = self._hourglass_downsample(x) + x0 = self._hourglass_downsample(x0) + in_downsampled = True + if x_at_loop_start is not None: + x_at_loop_start = self._hourglass_downsample(x_at_loop_start) + for skip_idx, i in enumerate(dec_iter): + if hg and in_downsampled and skip_idx == self.hourglass_up_before: + x = self._hourglass_upsample(x, full_len) + x0 = self._hourglass_upsample(x0, full_len) + in_downsampled = False + if x_at_loop_start is not None: + x_at_loop_start = self._hourglass_upsample(x_at_loop_start, full_len) + sg = torch.sigmoid(self.hourglass_skip_gate.to(dtype=x.dtype))[None, None, :] + x = (1 - sg) * x + sg * hg_skip + if self.loop_gate is not None and self.looping_active and i == self.loop_start_idx: + if not seen_loop_start: 
+ seen_loop_start = True + x_at_loop_start = x + else: + g = torch.sigmoid(self.loop_gate.to(dtype=x.dtype))[None, None, :] + x = (1 - g) * x + g * x_at_loop_start + if skip_idx < self.num_skip_weights and skips: + scaled_skip = skips.pop() + if hg and scaled_skip.size(1) != x.size(1): + if scaled_skip.size(1) > x.size(1): + scaled_skip = self._hourglass_downsample(scaled_skip) + else: + scaled_skip = self._hourglass_upsample(scaled_skip, x.size(1)) + scaled_skip = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * scaled_skip + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x, _ = self._block_with_lora(self.blocks[i], x, x0, lora, slot, v_residual=v_residual) + slot += 1 + if hg and in_downsampled: + x = self._hourglass_upsample(x, full_len) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.mos_k > 1: + tok_w = self.tok_emb.weight if self.tie_embeddings else self.lm_head.weight + sc = self.logit_softcap + pi = F.softmax(x @ self.mos_gate.to(dtype=x.dtype), dim=-1) + bsz, sl, dim = x.shape + V = tok_w.shape[0] + p_mixed = torch.zeros(bsz, sl, V, device=x.device, dtype=torch.float32) + for k in range(self.mos_k): + x_k = x * self.mos_scales[k].to(dtype=x.dtype)[None, None, :] + logits_k = F.linear(x_k, tok_w) + lora.lm_head_lora(x_k) + logits_k = sc * torch.tanh(logits_k / sc) + p_k = F.softmax(logits_k.float(), dim=-1) + p_mixed = p_mixed + pi[..., k:k+1].float() * p_k + p_mixed = p_mixed.clamp(min=1e-8) + flat_p = p_mixed.reshape(-1, V) + flat_t = target_ids.reshape(-1) + losses = -torch.log(flat_p.gather(1, flat_t.unsqueeze(1)).squeeze(1)) + return losses.reshape(bsz, sl) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + logits = self.logit_softcap * torch.tanh(logits / 
self.logit_softcap) + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + def _block_with_lora(self, block, x, x0, lora, slot, v_residual=None): + """Single block forward with LoRA injection, handles both parallel and sequential.""" + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Q with LoRA + q = (attn.c_q(n) + lora.q_loras[slot](n)).reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + # K with optional LoRA + k = attn.c_k(n) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + # V with LoRA + v = (attn.c_v(n) + lora.v_loras[slot](n)).reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + if _V_RESIDUAL_ENABLED and v_residual is not None: + v_mix = torch.sigmoid(attn.v_mix).to(dtype=v.dtype)[None, None, :, None] + v = v_mix * v + (1.0 - v_mix) * v_residual + v_out = v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + rep = attn.num_heads // attn.num_kv_heads + if rep > 1: + k = k[:,:,:,None,:].expand(bsz,seqlen,attn.num_kv_heads,rep,attn.head_dim).reshape(bsz,seqlen,attn.num_heads,attn.head_dim) + v = v[:,:,:,None,:].expand(bsz,seqlen,attn.num_kv_heads,rep,attn.head_dim).reshape(bsz,seqlen,attn.num_heads,attn.head_dim) + y = F.scaled_dot_product_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), is_causal=True).transpose(1,2) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + 
def classify_param(name):
    """Bucket a parameter name into an optimizer group: embed / mlp / attn / other."""
    if 'tok_emb' in name or 'lm_head' in name:
        return 'embed'
    if '.mlp.' in name:
        return 'mlp'
    # `and` binds tighter than `or`: any '.attn.' name matches, plus '.proj.'
    # names that are not MLP projections (the mlp case was handled above anyway).
    if '.attn.' in name or ('.proj.' in name and '.mlp.' not in name):
        return 'attn'
    return 'other'


@torch.compile
def zeropower_via_newtonschulz5(G, steps=5, eps=1e-07):
    """Approximately orthogonalize G via a quintic Newton-Schulz iteration.

    Uses the "Polar Express" per-iteration minimax-optimal coefficients and
    runs in bfloat16. Tall matrices are transposed first so the Gram matrix
    stays small.
    """
    coeffs = [
        (8.156554524902461, -22.48329292557795, 15.878769915207462),
        (4.042929935166739, -2.808917465908714, 0.5000178451051316),
        (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
        (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
        (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
    ]
    X = G.bfloat16()
    X /= X.norm() + eps
    tall = G.size(0) > G.size(1)
    if tall:
        X = X.T
    for a, b, c in coeffs[:steps]:
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if tall else X
distributed=dist.is_available()and dist.is_initialized();world_size=dist.get_world_size()if distributed else 1;rank=dist.get_rank()if distributed else 0 + for group in self.param_groups: + params=group['params'] + if not params:continue + lr=group['lr'];momentum=group['momentum'];backend_steps=group['backend_steps'];nesterov=group['nesterov'];total_params=sum(int(p.numel())for p in params);updates_flat=torch.zeros(total_params,device=params[0].device,dtype=torch.bfloat16);curr=0 + for(i,p)in enumerate(params): + if i%world_size==rank and p.grad is not None: + g=p.grad;state=self.state[p] + if'momentum_buffer'not in state:state['momentum_buffer']=torch.zeros_like(g) + buf=state['momentum_buffer'];buf.mul_(momentum).add_(g) + if nesterov:g=g.add(buf,alpha=momentum) + if group.get('row_normalize',False):row_norms=g.float().norm(dim=-1,keepdim=True).clamp_min(1e-07);g=g/row_norms.to(g.dtype) + g=zeropower_via_newtonschulz5(g,steps=backend_steps);g*=max(1,g.size(0)/g.size(1))**.5;updates_flat[curr:curr+p.numel()]=g.reshape(-1) + curr+=p.numel() + if distributed:dist.all_reduce(updates_flat,op=dist.ReduceOp.SUM) + wd=group.get('weight_decay',.0);curr=0 + for p in params: + if wd>.0:p.data.mul_(1.-lr*wd) + g=updates_flat[curr:curr+p.numel()].view_as(p).to(dtype=p.dtype);p.add_(g,alpha=-lr);curr+=p.numel() + return loss +CONTROL_TENSOR_NAME_PATTERNS=tuple(pattern for pattern in os.environ.get('CONTROL_TENSOR_NAME_PATTERNS','attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,v_mix,loop_gate').split(',')if pattern) +class Optimizers: + def __init__(self,h,base_model): + block_named_params=list(base_model.blocks.named_parameters());matrix_params=[p for(name,p)in block_named_params if p.ndim==2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)];scalar_params=[p for(name,p)in block_named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if 
base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel()>0:scalar_params.append(base_model.skip_gates) + if base_model.loop_gate is not None:scalar_params.append(base_model.loop_gate) + if getattr(base_model,'mos_k',1)>1: + for s in base_model.mos_scales:scalar_params.append(s) + scalar_params.append(base_model.mos_gate) + token_lr=h.tied_embed_lr if h.tie_embeddings else h.embed_lr;tok_params=[{'params':[base_model.tok_emb.weight],'lr':token_lr,'base_lr':token_lr}];self.optimizer_tok=torch.optim.AdamW(tok_params,betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.embed_wd,fused=True);self.optimizer_muon=Muon(matrix_params,lr=h.matrix_lr,momentum=h.muon_momentum,backend_steps=h.muon_backend_steps,weight_decay=h.muon_wd,row_normalize=h.muon_row_normalize) + for group in self.optimizer_muon.param_groups:group['base_lr']=h.matrix_lr + self.optimizer_scalar=torch.optim.AdamW([{'params':scalar_params,'lr':h.scalar_lr,'base_lr':h.scalar_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.adam_wd,fused=True);self.optimizers=[self.optimizer_tok,self.optimizer_muon,self.optimizer_scalar] + if base_model.lm_head is not None:self.optimizer_head=torch.optim.Adam([{'params':[base_model.lm_head.weight],'lr':h.head_lr,'base_lr':h.head_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,fused=True);self.optimizers.insert(1,self.optimizer_head) + else:self.optimizer_head=None + def __iter__(self):return iter(self.optimizers) + def zero_grad_all(self): + for opt in self.optimizers:opt.zero_grad(set_to_none=True) + def step(self): + for opt in self.optimizers:opt.step() + self.zero_grad_all() +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module,CastedLinear):module.float() + for(name,param)in model.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and 
param.dtype!=torch.float32:param.data=param.data.float() +def collect_hessians(model,train_loader,h,device,n_calibration_batches=64): + hessians={};hooks=[] + def make_hook(name): + def hook_fn(module,inp,out): + x=inp[0].detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + for(name,module)in model.named_modules(): + if isinstance(module,CastedLinear)and module.weight.numel()>65536: + cat=classify_param(name+'.weight') + if cat in('mlp','attn'):hooks.append(module.register_forward_hook(make_hook(name+'.weight'))) + if model.tie_embeddings: + hook_module=model.head_proj if model.head_proj is not None else model.final_norm + def make_output_hook(name): + def hook_fn(module,inp,out): + x=out.detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook('tok_emb.weight'))) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x) + for hook in hooks:hook.remove() + for name in hessians:hessians[name]=hessians[name].cpu()/n_calibration_batches + return hessians +def _gptq_core(W_orig,H_prepared,Hinv,perm,invperm,s,clip_range,block_size): + rows,cols=W_orig.shape;sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_orig[:,perm].clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in 
range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i2=1.:s_cand=(W_orig.abs().max(dim=1).values/clip_range).clamp_min(1e-10).to(torch.float16) + else:s_cand=(torch.quantile(W_orig.abs(),pct,dim=1)/clip_range).clamp_min(1e-10).to(torch.float16) + sf_c=s_cand.float();Q_c=torch.zeros(rows,cols,dtype=torch.int8);W_work_c=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work_c[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf_c),-clip_range,clip_range);Q_c[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf_c)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20.: + diagH=torch.diag(H).clamp_min(1e-8);col_imp=diagH/diagH.mean();row_imp=(W_orig.abs()*col_imp.unsqueeze(0)).mean(dim=1);row_imp=row_imp/row_imp.mean();adj=1.+hessian_clip_lambda*(row_imp-1.);s=(clip_sigmas*row_std*adj/clip_range).clamp_min(1e-10).to(torch.float16) + else:s=(clip_sigmas*row_std/clip_range).clamp_min(1e-10).to(torch.float16) + sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +def 
scale_tune_post_gptq(quant_result, quant_meta, template_sd, h, device): + """Optimize per-row quantization scales using actual CE loss on calibration data. + + After GPTQ produces initial scales, this function makes the scales learnable + and fine-tunes them via backprop through the dequantized model. The integer + quantized weights (Q_int) stay frozen; only the float16 scale tensors are updated. + This minimizes the actual CE loss instead of the per-layer MSE heuristic that GPTQ uses. + """ + log(f"Scale tuning: {h.scale_tuning_steps} steps, lr={h.scale_tuning_lr}, batches={h.scale_tuning_batches}") + t0 = time.perf_counter() + + # Build a shell model for functional_call (weights will be overridden) + tune_model = GPT(h).to(device).bfloat16() + restore_fp32_params(tune_model) + tune_model.eval() + + # Collect the frozen Q_int tensors and learnable scale params + scale_params = [] + q_int_map = {} # name -> frozen Q_int tensor on device + + for name, info in quant_meta.items(): + if 'gptq' not in info: + continue + q_key = name + '.q' + s_key = name + '.scale' + q_int_map[name] = quant_result[q_key].to(device) # frozen int8 + # Make scale a learnable parameter (float32 for optimizer stability) + scale_params.append((name, quant_result[s_key].float().to(device).requires_grad_(True))) + + # Build optimizer over scale params only + optim_params = [sp for _, sp in scale_params] + optimizer = torch.optim.Adam(optim_params, lr=h.scale_tuning_lr) + + # Load calibration data + calib_loader = ShuffledSequenceLoader(h, device) + + # Pre-collect calibration batches (small number, reuse across steps) + calib_data = [] + for _ in range(h.scale_tuning_batches): + x, y = calib_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + calib_data.append((x, y)) + + best_loss = float('inf') + best_scales = {name: sp.detach().clone() for name, sp in scale_params} + + for step in range(h.scale_tuning_steps): + optimizer.zero_grad() + total_loss = 0.0 + + # Build dequantized state 
dict (differentiable through scales) + deq_sd = {} + for pname, orig in template_sd.items(): + info = quant_meta.get(pname) + if info is None: + continue + if 'passthrough' in info: + t = quant_result[pname] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + deq_sd[pname] = t.to(device) + continue + q_int = q_int_map[pname] + scale_tensor = None + for sname, sp in scale_params: + if sname == pname: + scale_tensor = sp + break + if scale_tensor is not None: + deq_sd[pname] = (q_int.float() * scale_tensor.view(q_int.shape[0], 1)).to(orig.dtype) + else: + deq_sd[pname] = (q_int.float() * quant_result[pname + '.scale'].float().to(device).view(q_int.shape[0], 1)).to(orig.dtype) + + # Forward pass via functional_call (keeps gradients flowing through scales) + for batch_idx in range(len(calib_data)): + inp, tgt = calib_data[batch_idx] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True): + loss = torch.func.functional_call(tune_model, deq_sd, (inp, tgt)) + total_loss += loss.item() + (loss / len(calib_data)).backward() + + avg_loss = total_loss / len(calib_data) + optimizer.step() + + # Clamp scales to stay positive + with torch.no_grad(): + for _, sp in scale_params: + sp.clamp_min_(1e-10) + + if avg_loss < best_loss: + best_loss = avg_loss + best_scales = {name: sp.detach().clone() for name, sp in scale_params} + + if step % 5 == 0 or step == h.scale_tuning_steps - 1: + log(f" scale_tune step {step}: CE loss = {avg_loss:.6f}") + + # Write optimized scales back into quant_result + for name, sp in best_scales.items(): + quant_result[name + '.scale'] = sp.cpu().to(torch.float16) + + elapsed = time.perf_counter() - t0 + log(f"Scale tuning done in {elapsed:.1f}s, best CE = {best_loss:.6f}") + return quant_result + +_BSHF_MAGIC=b'BSHF' +def _byte_shuffle(data,stride=2): + if stride<=1 or len(data)target_bytes: + over=len(quant_blob)-target_bytes;log(f"prune:over by {over} bytes, selective 
pruning") + candidates=[] + for name,info in quant_meta.items(): + if'gptq'not in info:continue + q=quant_result[name+'.q'];s=quant_result[name+'.scale'] + sf=s.float().view(-1) if s.ndim==1 else s.float()[:,0] + for r in range(q.shape[0]): + for c in range(q.shape[1]): + v=int(q[r,c]) + if 01:nll=F.nll_loss(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + else:nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else context_size;scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=val_data.base_bytes_lut[tgt].to(torch.float64);tb+=(val_data.has_leading_space_lut[tgt]&~val_data.is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + base_model.train();return _loss_bpb(loss_sum,token_count,byte_count) +# ===== Batched LoRA TTT (ported from 1.06335 submission) ===== +class BatchedLinearLoRA(nn.Module): + """LoRA adapter with batched parameters [bsz, rank, dim] for parallel doc processing.""" + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter(torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound)) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + def forward(self, x): + 
return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + +class BatchedTTTLoRA(nn.Module): + """Container for all LoRA adapters needed for TTT on our model.""" + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.blocks[0].attn.c_q.in_features + vocab = model.tok_emb.num_embeddings + embed_dim = model.tok_emb.embedding_dim + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * model.blocks[0].attn.head_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)]) + self.v_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)]) + self.k_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)]) if k_lora else None + self.mlp_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)]) if mlp_lora else None + self.o_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)]) if o_lora else None + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + +BOS_ID = None + +def _find_docs(all_tokens): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def 
_select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted(random.Random(h.seed).sample(range(len(docs)), sample_n)) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [global_doc_entries[i : i + batch_size] for i in range(0, len(global_doc_entries), batch_size)] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb(ptl, x, y, chunk_offsets, chunk_lens, pos_idx, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + loss_sum, byte_sum, token_count): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ((chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1))) + mask_f64 = mask.to(torch.float64) + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += 
(has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to(torch.float64) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + """Global SGD TTT: train base model weights on scored prefix docs.""" + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD(ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if 
warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = h.global_ttt_warmup_start_lr + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * (1.0 + math.cos(math.pi * decay_ci / decay_steps)) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log(f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s") + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, deadline=None): + """Phased LoRA TTT: document-level batched scoring with per-doc LoRA adaptation.""" + 
global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log(f"ttt_phased: total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit} " + f"num_phases:{num_phases} boundaries:{phase_boundaries}") + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches(doc_entries, h, ascending=use_ascending) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + 
t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD(lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay) + return torch.optim.AdamW(lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True) + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + if deadline is not None and time.perf_counter()>=deadline: + log(f"ttt:eval_time_limit reached, scored {int(token_count.item())} tokens");break + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, 
dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to(device=device, dtype=torch.int64, non_blocking=True) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to(device, non_blocking=True) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + with torch.no_grad(): + _accumulate_bpb(per_tok_loss, x, y, chunk_offsets, chunk_lens, ctx_pos, + val_data.base_bytes_lut, val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, byte_sum, token_count) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + for gi in range(h.ttt_grad_steps): + if gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = 
forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[:, chunk_offset : chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log(f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}") + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch)) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = 
scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + log(f"ttpp: phase:{current_phase + 1}/{num_phases} " + f"gd:{len(scored_docs_for_global)} t:{time.perf_counter() - t_start:.1f}s") + train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, global_ttt_tokens) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + +def timed_eval(label,fn,*args,**kwargs):torch.cuda.synchronize();t0=time.perf_counter();val_loss,val_bpb=fn(*args,**kwargs);torch.cuda.synchronize();elapsed_ms=1e3*(time.perf_counter()-t0);log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms");return val_loss,val_bpb +def train_model(h,device,val_data): + 
compile_mode=os.environ.get('COMPILE_MODE','default');base_model=GPT(h).to(device).bfloat16();restore_fp32_params(base_model) + # Build byte-count LUT for byte-weighted CE loss + _byte_lut=None + if h.byte_weighted_ce: + sp=spm.SentencePieceProcessor(model_file=h.tokenizer_path);base_bytes_lut,_,_=build_sentencepiece_luts(sp,h.vocab_size,device);_byte_lut=base_bytes_lut.clamp(min=1,max=4).float();log(f"byte_weighted_ce:enabled, lut shape={_byte_lut.shape}, mean_weight={_byte_lut.mean().item():.3f}") + compiled_model=torch.compile(base_model,mode=compile_mode if compile_mode!='default'else None,dynamic=False,fullgraph=True) + if h.distributed:model=DDP(compiled_model,device_ids=[h.local_rank],broadcast_buffers=False) + else:model=compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}");optimizers=Optimizers(h,base_model);train_loader=ShuffledSequenceLoader(h,device);max_wallclock_ms=1e3*h.max_wallclock_seconds if h.max_wallclock_seconds>0 else None + if max_wallclock_ms is not None:max_wallclock_ms-=h.gptq_reserve_seconds*1e3;log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + def training_frac(step,elapsed_ms): + if max_wallclock_ms is None:return step/max(h.iterations,1) + return elapsed_ms/max(max_wallclock_ms,1e-09) + lr_schedule=os.environ.get('LR_SCHEDULE','linear') + def lr_mul(frac): + if h.warmdown_frac<=0:return 1. + if frac>=1.-h.warmdown_frac: + raw=(1.-frac)/h.warmdown_frac + if lr_schedule=='sqrt':return max(math.sqrt(raw),h.min_lr) + return max(raw,h.min_lr) + return 1. + _cur_batch_tokens=[h.train_batch_tokens];_kd_active=[h.kd_enabled];_approx_training_ms=[0.] 
+ if h.kd_enabled:log(f"kd:enabled alpha={h.kd_alpha} temp={h.kd_temperature} top_k={h.kd_top_k} dir={h.kd_logits_dir} warmup_frac={h.kd_warmup_frac}") + def step_fn(step,lr_scale): + optimizers.zero_grad_all();train_loss=torch.zeros((),device=device) + batch_tokens=_cur_batch_tokens[0] + for micro_step in range(h.grad_accum_steps): + if h.distributed:model.require_backward_grad_sync=micro_step==h.grad_accum_steps-1 + batch_data=train_loader.next_batch(batch_tokens,h.grad_accum_steps) + x,y=batch_data[0],batch_data[1] + bw=_byte_lut[y] if _byte_lut is not None else None + if _kd_active[0] and len(batch_data)==4: + t_idx,t_val=batch_data[2],batch_data[3] + kd_alpha_eff=h.kd_alpha + if h.kd_warmup_frac>0: + approx_ms=_approx_training_ms[0] + frac=training_frac(step,approx_ms) + if frac0 else 1.;muon_momentum=(1-frac)*h.muon_momentum_warmup_start+frac*h.muon_momentum + for group in optimizers.optimizer_muon.param_groups:group['momentum']=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group['lr']=group['base_lr']*lr_scale + if h.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),h.grad_clip_norm) + optimizers.step();return train_loss + if h.warmup_steps>0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,1.) + if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + + if len(base_model.encoder_indices)!=base_model.num_encoder_layers: + base_model.looping_active=True;log(f"loop_warmup:on") + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,1.) 
+ if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active=False + base_model.load_state_dict(initial_model_state,strict=True) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=True):opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed:model.require_backward_grad_sync=True + train_loader=ShuffledSequenceLoader(h,device) + ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=h.ema_decay;swa_state=None;swa_count=0;training_time_ms=.0;stop_after_step=None;torch.cuda.synchronize();t0=time.perf_counter();step=0 + while True: + last_step=step==h.iterations or stop_after_step is not None and step>=stop_after_step;should_validate=last_step or h.val_loss_every>0 and step%h.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(h,device,val_data,model);log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not None and step=h.enable_looping_at:base_model.looping_active=True;log(f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + if h.sparsity_start_frac>0 and frac>=h.sparsity_start_frac and not getattr(base_model,'_sparsity_applied',False): + with torch.no_grad(): + for module in base_model.modules(): + if isinstance(module,MLP): + for linear in[module.fc,module.proj]: + w=linear.weight.data + for i in range(0,w.shape[1],4): + group=w[:,i:i+4].abs();_,idx=group.topk(2,dim=1,largest=False) + for j in range(2):w[range(w.shape[0]),i+idx[:,j]]=0 + base_model._sparsity_applied=True;log(f"sparsity:applied 2:4 pruning at frac={frac:.3f}") + momentum_cooldown=float(os.environ.get('MOMENTUM_COOLDOWN',0.)) + if momentum_cooldown>0 and 
scale<1.: + cool_mom=h.muon_momentum-momentum_cooldown*(1.-scale); + for group in optimizers.optimizer_muon.param_groups:group['momentum']=cool_mom + train_loss=step_fn(step,scale) + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=1.-ema_decay) + swa_decay=float(os.environ.get('SWA_DECAY',0)) + if h.swa_enabled and scale0: + for(name,t)in base_model.state_dict().items():swa_state[name].mul_(swa_decay).add_(t.detach().float(),alpha=1.-swa_decay) + swa_count=1 + else: + for(name,t)in base_model.state_dict().items():swa_state[name].add_(t.detach().float()) + swa_count+=1 + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0);should_log_train=h.train_log_every>0 and(step<=5 or step%h.train_log_every==0 or stop_after_step is not None) + if should_log_train:tok_per_sec=step*h.train_batch_tokens/(approx_training_time_ms/1e3);log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}") + reached_cap=max_wallclock_ms is not None and approx_training_time_ms>=max_wallclock_ms + if h.distributed and max_wallclock_ms is not None:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap:stop_after_step=step + log(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");current_state=base_model.state_dict() + if h.swa_enabled and swa_state is not None and swa_count>0: + log(f"swa:applying SWA weights ({swa_count} checkpoints)");avg_state={name:(t/swa_count).to(dtype=current_state[name].dtype)for(name,t)in swa_state.items()} + else: + log('ema:applying EMA weights');avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()} + 
base_model.load_state_dict(avg_state,strict=True);return base_model,compiled_model +def train_and_eval(h,device): + eval_only=bool(int(os.environ.get('EVAL_ONLY','0'))) + random.seed(h.seed);np.random.seed(h.seed);torch.manual_seed(h.seed);torch.cuda.manual_seed_all(h.seed);val_data=ValidationData(h,device);log(f"val_tokens: {val_data.val_tokens.numel()-1}") + if eval_only: + log("EVAL_ONLY=1: skipping training and serialization, loading existing quantized model") + else: + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}");base_model,compiled_model=train_model(h,device,val_data);torch._dynamo.reset();timed_eval('pre-quantization post-ema',eval_val,h,device,val_data,compiled_model);serialize(h,base_model,Path(__file__).read_text(encoding='utf-8')) + if h.distributed:dist.barrier() + torch.cuda.synchronize();eval_t0=time.perf_counter();eval_deadline=eval_t0+h.max_eval_seconds if h.max_eval_seconds>0 else None;log(f"eval:budget {h.max_eval_seconds:.0f}s") + eval_model=deserialize(h,device) + if len(eval_model.encoder_indices)!=eval_model.num_encoder_layers:eval_model.looping_active=True + compiled_model=torch.compile(eval_model,dynamic=False,fullgraph=True);timed_eval('quantized',eval_val,h,device,val_data,compiled_model) + if h.sliding_window_enabled:timed_eval('quantized_sliding_window',eval_val_sliding,h,device,val_data,eval_model) + if h.ttt_enabled and h.ttt_lora_rank > 0: + ttt_model=deserialize(h,device) + if len(ttt_model.encoder_indices)!=ttt_model.num_encoder_layers:ttt_model.looping_active=True + # Warm up rotary caches for TTT eval seq len + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + _fwd_ttt_compiled_inner = None + def 
_fwd_ttt(input_ids, target_ids, lora): + nonlocal _fwd_ttt_compiled_inner + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + fwd_ttt_compiled = _fwd_ttt + # Compile warmup with random tokens + log(f"ttt_lora:warming up compile") + t_warmup = time.perf_counter() + for bsz_w in [h.ttt_batch_size]: + wl = BatchedTTTLoRA(bsz_w, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora).to(device) + wo = torch.optim.AdamW(wl.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), eps=1e-10, + weight_decay=h.ttt_weight_decay, fused=True) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz_w, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz_w, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + log(f"ttt_lora:compile warmup done ({time.perf_counter() - t_warmup:.1f}s)") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log("beginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased(h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled, deadline=eval_deadline) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log(f"quantized_ttt_phased val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} eval_time:{ttt_eval_elapsed*1e3:.0f}ms") + torch.cuda.synchronize();total_eval_elapsed=time.perf_counter()-eval_t0;log(f"TOTAL_EVAL_TIME: {total_eval_elapsed:.1f}s ({total_eval_elapsed/60:.1f}m) 
budget:{h.max_eval_seconds:.0f}s") +def main(): + world_size=int(os.environ.get('WORLD_SIZE','1'));local_rank=int(os.environ.get('LOCAL_RANK','0'));distributed='RANK'in os.environ and'WORLD_SIZE'in os.environ + if not torch.cuda.is_available():raise RuntimeError('CUDA is required') + if world_size<=0:raise ValueError(f"bad ws") + if 8%world_size!=0:raise ValueError(f"ws must divide 8") + device=torch.device('cuda',local_rank);torch.cuda.set_device(device) + if distributed:dist.init_process_group(backend='nccl',device_id=device);dist.barrier() + torch.backends.cuda.matmul.allow_tf32=True;torch.backends.cudnn.allow_tf32=True;torch.set_float32_matmul_precision('high');from torch.backends.cuda import enable_cudnn_sdp,enable_flash_sdp,enable_math_sdp,enable_mem_efficient_sdp;enable_cudnn_sdp(False);enable_flash_sdp(True);enable_mem_efficient_sdp(True);enable_math_sdp(False);torch._dynamo.config.optimize_ddp=False + h=Hyperparameters();set_logging_hparams(h) + if h.is_main_process: + os.makedirs('logs',exist_ok=True);log('Hyperparameters:',console=True) + for(k,v)in sorted(vars(type(h)).items()): + if not k.startswith('_'):log(f" {k}: {v}",console=True) + log('='*100,console=False) + train_and_eval(h,device) + if distributed:dist.destroy_process_group() +if __name__=='__main__':main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/train_gpt_readable.py b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/train_gpt_readable.py new file mode 100644 index 0000000000..dff35b801d --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_PiyushDatta_SP8192_DepthRecur_PolarNS_LoRATTT/train_gpt_readable.py @@ -0,0 +1,1696 @@ +import collections,copy,fcntl,glob,io,lzma,math,os +from pathlib import Path +import random,re,subprocess,sys,time,uuid,numpy as np,sentencepiece as spm,torch,torch.distributed as dist,torch.nn.functional as F +from torch.nn.parallel import 
DistributedDataParallel as DDP +from torch import Tensor,nn +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3=True +except ImportError: + _HAS_FA3=False +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager softcap*tanh(x/softcap) + F.cross_entropy with a single +# fused Triton kernel that reads logits once, applies softcap in-register, and +# computes (LSE, loss) in one streaming pass. The backward kernel mirrors the +# forward so there's no stored softcapped logits tensor. +# Math note: the kernel uses z = 2C*sigmoid(2x/C) instead of C*tanh(x/C). +# These differ by a constant +C, which cancels in log_softmax (shift invariance). +_FUSED_CE_ENABLED = bool(int(os.environ.get('FUSED_CE_ENABLED', '0'))) +_FUSED_CE_LIBRARY = "piyushdatta_fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 # 8192 vocab / 1024 = 8 iterations, perfect alignment +_FUSED_CE_NUM_WARPS = 4 +if _FUSED_CE_ENABLED: + import triton + import triton.language as tl + @triton.jit + def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, + ): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val 
= tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + @triton.jit + def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, + ): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + def _validate_softcapped_ce_inputs(logits, targets, softcap): + if logits.ndim != 2: raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: raise ValueError(f"Row mismatch logits={tuple(logits.shape)} targets={tuple(targets.shape)}") + if not logits.is_cuda or not targets.is_cuda: raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: raise ValueError(f"softcap must be positive, got {softcap}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: targets = targets.to(dtype=torch.int64) + 
return logits, targets + @torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) + def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + @softcapped_ce_op.register_fake + def _(logits: Tensor, targets: Tensor, softcap: float): + n_rows = logits.shape[0] + return (logits.new_empty((n_rows,), dtype=torch.float32), logits.new_empty((n_rows,), dtype=torch.float32)) + @torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) + def softcapped_ce_backward_op(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous(); grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + @softcapped_ce_backward_op.register_fake + def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + return logits.new_empty(logits.shape) + def _softcapped_ce_setup_context(ctx, inputs, output): + logits, targets, softcap = inputs + _losses, lse = output + 
ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + def _softcapped_ce_backward(ctx, grad_losses, grad_lse): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = getattr(torch.ops, _FUSED_CE_LIBRARY).softcapped_ce_backward(logits, targets, lse, grad_losses, ctx.softcap) + return grad_logits, None, None + softcapped_ce_op.register_autograd(_softcapped_ce_backward, setup_context=_softcapped_ce_setup_context) + def softcapped_cross_entropy(logits, targets, softcap, reduction="mean"): + losses, _lse = getattr(torch.ops, _FUSED_CE_LIBRARY).softcapped_ce(logits, targets, float(softcap)) + if reduction == "none": return losses + if reduction == "sum": return losses.sum() + if reduction == "mean": return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") +class Hyperparameters:data_dir=os.environ.get('DATA_DIR','./data/');seed=int(os.environ.get('SEED',1337));run_id=os.environ.get('RUN_ID',str(uuid.uuid4()));iterations=int(os.environ.get('ITERATIONS',20000));warmdown_frac=float(os.environ.get('WARMDOWN_FRAC',.72));warmup_steps=int(os.environ.get('WARMUP_STEPS',20));train_batch_tokens=int(os.environ.get('TRAIN_BATCH_TOKENS',393216));train_seq_len=int(os.environ.get('TRAIN_SEQ_LEN',2048));train_log_every=int(os.environ.get('TRAIN_LOG_EVERY',500));max_wallclock_seconds=float(os.environ.get('MAX_WALLCLOCK_SECONDS',6e2));val_batch_tokens=int(os.environ.get('VAL_BATCH_TOKENS',524288));eval_seq_len=int(os.environ.get('EVAL_SEQ_LEN',2048));val_loss_every=int(os.environ.get('VAL_LOSS_EVERY',4000));sliding_window_enabled=bool(int(os.environ.get('SLIDING_WINDOW_ENABLED','1')));vocab_size=int(os.environ.get('VOCAB_SIZE',8192));num_layers=int(os.environ.get('NUM_LAYERS',11));xsa_last_n=int(os.environ.get('XSA_LAST_N',11));model_dim=int(os.environ.get('MODEL_DIM',512));embedding_dim=int(os.environ.get('EMBEDDING_DIM',512));num_kv_heads=int(os.environ.get('NUM_KV_HEADS',4));num_heads=int(os.environ.get('NUM_HEADS',8));mlp
_mult=float(os.environ.get('MLP_MULT',4.));skip_gates_enabled=bool(int(os.environ.get('SKIP_GATES_ENABLED','1')));tie_embeddings=bool(int(os.environ.get('TIE_EMBEDDINGS','1')));logit_softcap=float(os.environ.get('LOGIT_SOFTCAP',3e1));rope_base=float(os.environ.get('ROPE_BASE',1e4));rope_dims=int(os.environ.get('ROPE_DIMS',16));rope_train_seq_len=int(os.environ.get('ROPE_TRAIN_SEQ_LEN',2048));ln_scale=bool(int(os.environ.get('LN_SCALE','1')));qk_gain_init=float(os.environ.get('QK_GAIN_INIT',5.));parallel_residual_start=int(os.environ.get('PARALLEL_RESIDUAL_START',7));enable_looping_at=float(os.environ.get('ENABLE_LOOPING_AT',.45));min_lr=float(os.environ.get('MIN_LR',.10));embed_lr=float(os.environ.get('EMBED_LR',.6));head_lr=float(os.environ.get('HEAD_LR',.008));tied_embed_lr=float(os.environ.get('TIED_EMBED_LR',.03));tied_embed_init_std=float(os.environ.get('TIED_EMBED_INIT_STD',.005));matrix_lr=float(os.environ.get('MATRIX_LR',.028));scalar_lr=float(os.environ.get('SCALAR_LR',.02));muon_momentum=float(os.environ.get('MUON_MOMENTUM',.95));muon_backend_steps=int(os.environ.get('MUON_BACKEND_STEPS',5));muon_momentum_warmup_start=float(os.environ.get('MUON_MOMENTUM_WARMUP_START',.85));muon_momentum_warmup_steps=int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS',500));muon_row_normalize=bool(int(os.environ.get('MUON_ROW_NORMALIZE','1')));beta1=float(os.environ.get('BETA1',.9));beta2=float(os.environ.get('BETA2',.95));adam_eps=float(os.environ.get('ADAM_EPS',1e-08));grad_clip_norm=float(os.environ.get('GRAD_CLIP_NORM',.3));eval_stride=int(os.environ.get('EVAL_STRIDE',64));adam_wd=float(os.environ.get('ADAM_WD',.02));muon_wd=float(os.environ.get('MUON_WD',.095));embed_wd=float(os.environ.get('EMBED_WD',.085));ema_decay=float(os.environ.get('EMA_DECAY',.9965));compressor=os.environ.get('COMPRESSOR','brotli');swa_enabled=bool(int(os.environ.get('SWA_ENABLED','1')));swa_start_frac=float(os.environ.get('SWA_START_FRAC',.12));swa_every=int(os.environ.get('SWA_EVERY',2));gptq_cal
ibration_batches=int(os.environ.get('GPTQ_CALIBRATION_BATCHES',64));gptq_reserve_seconds=float(os.environ.get('GPTQ_RESERVE_SECONDS',12.));matrix_bits=int(os.environ.get('MATRIX_BITS',6));embed_bits=int(os.environ.get('EMBED_BITS',8));matrix_clip_sigmas=float(os.environ.get('MATRIX_CLIP_SIGMAS',12.85));mlp_clip_sigmas=float(os.environ.get('MLP_CLIP_SIGMAS',0));attn_clip_sigmas=float(os.environ.get('ATTN_CLIP_SIGMAS',0));embed_clip_sigmas=float(os.environ.get('EMBED_CLIP_SIGMAS',2e1));hessian_clip_lambda=float(os.environ.get('HESSIAN_CLIP_LAMBDA',.175));eval_temperature=float(os.environ.get('EVAL_TEMPERATURE',1.));sparsity_start_frac=float(os.environ.get('SPARSITY_START_FRAC',0.));ttt_enabled=bool(int(os.environ.get('TTT_ENABLED','1')));ttt_lr=float(os.environ.get('TTT_LR',.02));ttt_epochs=int(os.environ.get('TTT_EPOCHS',3));ttt_momentum=float(os.environ.get('TTT_MOMENTUM',.9));ttt_chunk_tokens=int(os.environ.get('TTT_CHUNK_TOKENS',32768));ttt_lora_rank=int(os.environ.get('TTT_LORA_RANK',96));ttt_ns_steps=int(os.environ.get('TTT_NS_STEPS',0));ttt_swa=bool(int(os.environ.get('TTT_SWA','0')));ttt_reset_per_chunk=bool(int(os.environ.get('TTT_RESET_PER_CHUNK','0')));ttt_lora_alpha=float(os.environ.get('TTT_LORA_ALPHA',144));ttt_warm_start_a=bool(int(os.environ.get('TTT_WARM_START_A','1')));ttt_lora_lr=float(os.environ.get('TTT_LORA_LR',0.0001));ttt_chunk_size=int(os.environ.get('TTT_CHUNK_SIZE',48));ttt_batch_size=int(os.environ.get('TTT_BATCH_SIZE',64));ttt_grad_steps=int(os.environ.get('TTT_GRAD_STEPS',1));ttt_weight_decay=float(os.environ.get('TTT_WEIGHT_DECAY',1.0));ttt_beta1=float(os.environ.get('TTT_BETA1',0));ttt_beta2=float(os.environ.get('TTT_BETA2',0.999));ttt_k_lora=bool(int(os.environ.get('TTT_K_LORA','1')));ttt_mlp_lora=bool(int(os.environ.get('TTT_MLP_LORA','1')));ttt_o_lora=bool(int(os.environ.get('TTT_O_LORA','1')));ttt_optimizer=os.environ.get('TTT_OPTIMIZER','adam');val_doc_fraction=float(os.environ.get('VAL_DOC_FRACTION',1.0));phased_ttt_prefix_docs=in
t(os.environ.get('PHASED_TTT_PREFIX_DOCS',2000));phased_ttt_num_phases=int(os.environ.get('PHASED_TTT_NUM_PHASES',1));global_ttt_lr=float(os.environ.get('GLOBAL_TTT_LR',0.001));global_ttt_momentum=float(os.environ.get('GLOBAL_TTT_MOMENTUM',0.9));global_ttt_epochs=int(os.environ.get('GLOBAL_TTT_EPOCHS',1));global_ttt_chunk_tokens=int(os.environ.get('GLOBAL_TTT_CHUNK_TOKENS',32768));global_ttt_batch_seqs=int(os.environ.get('GLOBAL_TTT_BATCH_SEQS',32));global_ttt_warmup_start_lr=float(os.environ.get('GLOBAL_TTT_WARMUP_START_LR',0.0));global_ttt_warmup_chunks=int(os.environ.get('GLOBAL_TTT_WARMUP_CHUNKS',0));global_ttt_grad_clip=float(os.environ.get('GLOBAL_TTT_GRAD_CLIP',0));ttt_eval_seq_len=int(os.environ.get('TTT_EVAL_SEQ_LEN',2048));ttt_eval_batches=os.environ.get('TTT_EVAL_BATCHES','');mos_k=int(os.environ.get('MOS_K',1));byte_weighted_ce=bool(int(os.environ.get('BYTE_WEIGHTED_CE','0')));batch_schedule_enabled=bool(int(os.environ.get('BATCH_SCHEDULE_ENABLED','0')));scale_tuning_enabled=bool(int(os.environ.get('SCALE_TUNING_ENABLED','0')));scale_tuning_steps=int(os.environ.get('SCALE_TUNING_STEPS',20));scale_tuning_lr=float(os.environ.get('SCALE_TUNING_LR',0.001));scale_tuning_batches=int(os.environ.get('SCALE_TUNING_BATCHES',8));distributed='RANK'in os.environ and'WORLD_SIZE'in 
os.environ;rank=int(os.environ.get('RANK','0'));world_size=int(os.environ.get('WORLD_SIZE','1'));local_rank=int(os.environ.get('LOCAL_RANK','0'));is_main_process=rank==0;grad_accum_steps=int(os.environ.get('GRAD_ACCUM_STEPS',str(8//world_size)));datasets_dir=os.path.join(data_dir,'datasets',f"fineweb10B_sp{vocab_size}");train_files=os.path.join(datasets_dir,'fineweb_train_*.bin');val_files=os.path.join(datasets_dir,'fineweb_val_*.bin');tokenizer_path=os.path.join(data_dir,'tokenizers',f"fineweb_{vocab_size}_bpe.model");logfile=f"logs/{run_id}.txt";model_path='final_model.pt';quantized_model_path='final_model.int6.ptz';kd_enabled=bool(int(os.environ.get('KD_ENABLED','0')));kd_alpha=float(os.environ.get('KD_ALPHA',0.5));kd_temperature=float(os.environ.get('KD_TEMPERATURE',2.0));kd_top_k=int(os.environ.get('KD_TOP_K',32));kd_logits_dir=os.environ.get('KD_LOGITS_DIR','');kd_warmup_frac=float(os.environ.get('KD_WARMUP_FRAC',0.0));hard_token_mining=bool(int(os.environ.get('HARD_TOKEN_MINING','0')));hard_token_frac=float(os.environ.get('HARD_TOKEN_FRAC',0.5));hard_token_boost=float(os.environ.get('HARD_TOKEN_BOOST',2.0));aux_loss_weight=float(os.environ.get('AUX_LOSS_WEIGHT',0.));multi_traj_swa=bool(int(os.environ.get('MULTI_TRAJ_SWA','0')));lqer_enabled=bool(int(os.environ.get('LQER_ENABLED','0')));lqer_rank=int(os.environ.get('LQER_RANK','4'));lqer_top_k=int(os.environ.get('LQER_TOP_K','3'));lqer_asym_group=int(os.environ.get('LQER_ASYM_GROUP','64')) +_logger_hparams=None +def set_logging_hparams(h):global _logger_hparams;_logger_hparams=h +def log(msg,console=True): + if _logger_hparams is None:print(msg);return + if _logger_hparams.is_main_process: + if console:print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile,'a',encoding='utf-8')as f:print(msg,file=f) +class ValidationData: + def __init__(self,h,device): + self.sp=spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size())!=h.vocab_size:raise 
ValueError(f"vocab mismatch") + self.val_tokens=load_validation_tokens(h.val_files,h.eval_seq_len);self.base_bytes_lut,self.has_leading_space_lut,self.is_boundary_token_lut=build_sentencepiece_luts(self.sp,h.vocab_size,device) +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());assert sp.piece_to_id('▁')!=sp.unk_id(),"bad tokenizer";table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=False + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if piece.startswith('▁'):has_leading_space_np[token_id]=True;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode('utf-8')) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"no files: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"val too short") + return tokens[:usable+1] +def load_data_shard(file): + header_bytes=256*np.dtype('0 else 0;num_sequences=(self.num_tokens[si]-1-phase)//self.seq_len;sequence_order=self.rng.permutation(num_sequences);self.start_inds[si]=(phase+sequence_order*self.seq_len).tolist() + def next_batch(self,global_tokens,grad_accum_steps): + device_tokens=global_tokens//(self.world_size*grad_accum_steps);device_batch_size=device_tokens//self.seq_len;remaining=np.array([len(s)for s in 
self.start_inds],dtype=np.float64);x=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64);y=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64) + teacher_topk_idx=None;teacher_topk_val=None + if self.kd_enabled:teacher_topk_idx=torch.empty((device_batch_size,self.seq_len,self.kd_top_k),dtype=torch.int64);teacher_topk_val=torch.empty((device_batch_size,self.seq_len,self.kd_top_k),dtype=torch.float32) + for bi in range(device_batch_size): + total=remaining.sum() + if total<=0: + for si in range(len(self.files)):self._reset_shard(si) + remaining=np.array([len(s)for s in self.start_inds],dtype=np.float64);total=remaining.sum() + probs=remaining/total;si=int(self.rng.choice(len(self.files),p=probs));start_ind=self.start_inds[si].pop();remaining[si]-=1;mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[start_ind:start_ind+self.seq_len+1],dtype=np.int64));x[bi]=window[:-1];y[bi]=window[1:] + if self.kd_enabled and teacher_topk_idx is not None: + tmaps=_get_teacher_logits_memmaps(self.files[si],self.kd_logits_dir,self.kd_top_k) + if tmaps is not None: + t_idx_mm,t_val_mm=tmaps;t_start=start_ind;t_end=start_ind+self.seq_len + if t_end<=t_idx_mm.shape[0]:teacher_topk_idx[bi]=torch.as_tensor(np.array(t_idx_mm[t_start:t_end],dtype=np.int64));teacher_topk_val[bi]=torch.as_tensor(np.array(t_val_mm[t_start:t_end],dtype=np.float32)) + else:teacher_topk_idx[bi]=0;teacher_topk_val[bi]=0. + else:teacher_topk_idx[bi]=0;teacher_topk_val[bi]=0. 
+ result=(x.to(self.device,non_blocking=True),y.to(self.device,non_blocking=True)) + if self.kd_enabled and teacher_topk_idx is not None:result=result+(teacher_topk_idx.to(self.device,non_blocking=True),teacher_topk_val.to(self.device,non_blocking=True)) + return result +class RMSNorm(nn.Module): + def __init__(self,eps=None):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + def forward(self,x): + w=self.weight.to(x.dtype) + bias=self.bias.to(x.dtype)if self.bias is not None else None;return F.linear(x,w,bias) +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=1./base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=False);self._seq_len_cached=0;self._cos_cached=None;self._sin_cached=None + def forward(self,seq_len,device,dtype): + if self._cos_cached is None or self._sin_cached is None or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=1./new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[None,:,None,:];self._sin_cached=freqs.sin()[None,:,None,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and 
rope_dims1:k=k[:,:,:,None,:].expand(bsz,seqlen,self.num_kv_heads,rep,self.head_dim).reshape(bsz,seqlen,self.num_heads,self.head_dim);v=v[:,:,:,None,:].expand(bsz,seqlen,self.num_kv_heads,rep,self.head_dim).reshape(bsz,seqlen,self.num_heads,self.head_dim) + y=F.scaled_dot_product_attention(q.transpose(1,2),k.transpose(1,2),v.transpose(1,2),is_causal=True).transpose(1,2) + if self.use_xsa:y=self._xsa_efficient(y,v) + y=y.reshape(bsz,seqlen,dim);return self.proj(y),v_out +class MLP(nn.Module): + def __init__(self,dim,mlp_mult):super().__init__();hidden=int(mlp_mult*dim);self.fc=CastedLinear(dim,hidden,bias=False);self.proj=CastedLinear(hidden,dim,bias=False);self.proj._zero_init=True;self.gated=bool(int(os.environ.get('GATED_MLP','0'))) + def forward(self,x): + h=self.fc(x);act=F.leaky_relu(h,negative_slope=.5).square() + if self.gated:act=act*torch.sigmoid(h) + return self.proj(act) +class Block(nn.Module): + def __init__(self,dim,num_heads,num_kv_heads,mlp_mult,rope_base,qk_gain_init,train_seq_len,layer_idx=0,ln_scale=False):super().__init__();self.attn_norm=RMSNorm();self.mlp_norm=RMSNorm();self.attn=CausalSelfAttention(dim,num_heads,num_kv_heads,rope_base,qk_gain_init,train_seq_len);self.mlp=MLP(dim,mlp_mult);self.attn_scale=nn.Parameter(torch.ones(dim,dtype=torch.float32));self.mlp_scale=nn.Parameter(torch.ones(dim,dtype=torch.float32));self.resid_mix=nn.Parameter(torch.stack((torch.ones(dim),torch.zeros(dim))).float());self.ln_scale_factor=1./math.sqrt(layer_idx+1)if ln_scale else 1.;self.parallel=False + def forward(self,x,x0,v_residual=None): + mix=self.resid_mix.to(dtype=x.dtype);x_in=mix[0][None,None,:]*x+mix[1][None,None,:]*x0;attn_out,v_out=self.attn(self.attn_norm(x_in)*self.ln_scale_factor,v_residual=v_residual) + if self.parallel:mlp_out=self.mlp(self.mlp_norm(x_in)*self.ln_scale_factor);x_out=x_in+self.attn_scale.to(dtype=x_in.dtype)[None,None,:]*attn_out+self.mlp_scale.to(dtype=x_in.dtype)[None,None,:]*mlp_out + 
else:x_out=x_in+self.attn_scale.to(dtype=x_in.dtype)[None,None,:]*attn_out;x_out=x_out+self.mlp_scale.to(dtype=x_out.dtype)[None,None,:]*self.mlp(self.mlp_norm(x_out)*self.ln_scale_factor) + return x_out,v_out +def sparse_kd_loss(student_logits, teacher_topk_idx, teacher_topk_val, temperature, softcap=0.): + """Compute sparse KL-divergence loss between student and teacher (top-K only). + + Args: + student_logits: [N, V] raw student logits (pre- or post-softcap) + teacher_topk_idx: [N, K] top-K token indices from teacher + teacher_topk_val: [N, K] top-K logit values from teacher (post-softcap) + temperature: temperature for softening distributions + softcap: if > 0, apply softcap to student logits first + Returns: + scalar KD loss (mean over batch) + """ + # Apply softcap to student if needed + if softcap > 0.: + student_logits = softcap * torch.tanh(student_logits / softcap) + # Scale by temperature + s_scaled = student_logits.float() / temperature + t_scaled = teacher_topk_val.float() / temperature + # Student log-softmax over full vocab (for log_sum_exp) + s_log_softmax = s_scaled - torch.logsumexp(s_scaled, dim=-1, keepdim=True) + # Teacher softmax over top-K only (renormalized) + t_probs = F.softmax(t_scaled, dim=-1) + # Gather student log-probs at teacher's top-K positions + s_log_probs_at_topk = s_log_softmax.gather(1, teacher_topk_idx) + # KL = sum_k t_prob_k * (log(t_prob_k) - s_log_prob_k) + # = sum_k t_prob_k * log(t_prob_k) - sum_k t_prob_k * s_log_prob_k + # First term is teacher entropy (constant wrt student), we include it for proper KL + t_log_probs = torch.log(t_probs.clamp(min=1e-8)) + kl = (t_probs * (t_log_probs - s_log_probs_at_topk)).sum(dim=-1) + return (kl * temperature * temperature).mean() +class GPT(nn.Module): + def __init__(self,h): + super().__init__() + if h.logit_softcap<=.0:raise ValueError(f"bad softcap") + 
self.tie_embeddings=h.tie_embeddings;self.tied_embed_init_std=h.tied_embed_init_std;self.logit_softcap=h.logit_softcap;self.fused_ce_enabled=_FUSED_CE_ENABLED;self.aux_loss_weight=h.aux_loss_weight;self._aux_hidden=None;self.tok_emb=nn.Embedding(h.vocab_size,h.embedding_dim) + if h.embedding_dim!=h.model_dim:self.embed_proj=CastedLinear(h.embedding_dim,h.model_dim,bias=False);self.head_proj=CastedLinear(h.model_dim,h.embedding_dim,bias=False) + else:self.embed_proj=None;self.head_proj=None + self.num_encoder_layers=h.num_layers//2;self.num_decoder_layers=h.num_layers-self.num_encoder_layers;self.blocks=nn.ModuleList([Block(h.model_dim,h.num_heads,h.num_kv_heads,h.mlp_mult,h.rope_base,h.qk_gain_init,h.train_seq_len,layer_idx=i,ln_scale=h.ln_scale)for i in range(h.num_layers)]) + if h.rope_dims>0: + head_dim=h.model_dim//h.num_heads + for block in self.blocks:block.attn.rope_dims=h.rope_dims;block.attn.rotary=Rotary(head_dim,base=h.rope_base,train_seq_len=h.train_seq_len,rope_dims=h.rope_dims) + self.final_norm=RMSNorm();self.lm_head=None if h.tie_embeddings else CastedLinear(h.embedding_dim,h.vocab_size,bias=False) + if self.lm_head is not None:self.lm_head._zero_init=True + if h.xsa_last_n>0: + for i in range(max(0,h.num_layers-h.xsa_last_n),h.num_layers):self.blocks[i].attn.use_xsa=True + if h.parallel_residual_start>=0: + for i in range(h.parallel_residual_start,h.num_layers):self.blocks[i].parallel=True + self.hourglass_enabled=bool(int(os.environ.get('HOURGLASS_ENABLED','0')));self.hourglass_down_after=int(os.environ.get('HOURGLASS_DOWN_AFTER',3));self.hourglass_up_before=int(os.environ.get('HOURGLASS_UP_BEFORE',8));self.hourglass_factor=int(os.environ.get('HOURGLASS_FACTOR',2)) + if self.hourglass_enabled:self.hourglass_skip_gate=nn.Parameter(torch.zeros(h.model_dim,dtype=torch.float32)) + self.looping_active=False;num_loops=int(os.environ.get('NUM_LOOPS',1));self.loop_start_idx=int(os.environ.get('LOOP_START',3)) + if num_loops>0: + 
ls=self.loop_start_idx;le=int(os.environ.get('LOOP_END',5));loop_seg=list(range(ls,le+1));all_idx=list(range(ls)) + for _ in range(num_loops+1):all_idx.extend(loop_seg) + all_idx.extend(range(le+1,h.num_layers));ne=len(all_idx)//2;self.encoder_indices=all_idx[:ne];self.decoder_indices=all_idx[ne:] + else:self.encoder_indices=list(range(self.num_encoder_layers));self.decoder_indices=list(range(self.num_encoder_layers,h.num_layers)) + self.num_skip_weights=min(len(self.encoder_indices),len(self.decoder_indices));self.skip_weights=nn.Parameter(torch.ones(self.num_skip_weights,h.model_dim,dtype=torch.float32));self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,h.model_dim,dtype=torch.float32))if h.skip_gates_enabled else None + self.loop_gate_enabled=bool(int(os.environ.get('LOOP_GATE_ENABLED','0'))) + if self.loop_gate_enabled:self.loop_gate=nn.Parameter(torch.full((h.model_dim,),-2.0,dtype=torch.float32)) + else:self.loop_gate=None + self.mos_k=h.mos_k + if self.mos_k>1: + self.mos_scales=nn.ParameterList([nn.Parameter(torch.ones(h.embedding_dim,dtype=torch.float32))for _ in range(self.mos_k)]) + self.mos_gate=nn.Parameter(torch.zeros(h.embedding_dim,self.mos_k,dtype=torch.float32)) + self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=.0,std=self.tied_embed_init_std) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if getattr(module,'_zero_init',False):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=1.) + def _hourglass_downsample(self,x): + """Downsample sequence T -> T/factor via avg_pool1d. x: [B, T, D] -> [B, T//factor, D]""" + f=self.hourglass_factor + return F.avg_pool1d(x.transpose(1,2),kernel_size=f,stride=f).transpose(1,2) + def _hourglass_upsample(self,x,target_len): + """Upsample sequence back to target_len via repeat_interleave. 
x: [B, T_down, D] -> [B, target_len, D]""" + x_up=x.repeat_interleave(self.hourglass_factor,dim=1) + if x_up.size(1)>target_len:x_up=x_up[:,:target_len,:] + return x_up + def _forward_hidden(self,input_ids): + """Forward through all blocks, returns hidden states before final linear projection.""" + x=self.tok_emb(input_ids);x=F.rms_norm(x,(x.size(-1),)) + if self.embed_proj is not None:x=self.embed_proj(x) + x0=x;skips=[];v_residual=None;enc_iter=self.encoder_indices if self.looping_active else range(self.num_encoder_layers);dec_iter=self.decoder_indices if self.looping_active else range(self.num_encoder_layers,self.num_encoder_layers+self.num_decoder_layers);is_first_layer=True;seen_loop_start=False;x_at_loop_start=None + hg=self.hourglass_enabled;full_len=x.size(1);hg_skip=None;in_downsampled=False + for step_idx,i in enumerate(enc_iter): + if self.loop_gate is not None and self.looping_active and i==self.loop_start_idx: + if not seen_loop_start:seen_loop_start=True;x_at_loop_start=x + else:g=torch.sigmoid(self.loop_gate.to(dtype=x.dtype))[None,None,:];x=(1-g)*x+g*x_at_loop_start + x,v_out=self.blocks[i](x,x0,v_residual=v_residual) + if _V_RESIDUAL_ENABLED and is_first_layer:v_residual=v_out;is_first_layer=False + skips.append(x) + if hg and not in_downsampled and step_idx==self.hourglass_down_after: + hg_skip=x;x=self._hourglass_downsample(x);x0=self._hourglass_downsample(x0);in_downsampled=True + if x_at_loop_start is not None:x_at_loop_start=self._hourglass_downsample(x_at_loop_start) + if v_residual is not None:v_residual=v_residual[:,::self.hourglass_factor,:,:] + self._aux_hidden=x + for(skip_idx,i)in enumerate(dec_iter): + if hg and in_downsampled and skip_idx==self.hourglass_up_before: + x=self._hourglass_upsample(x,full_len);x0=self._hourglass_upsample(x0,full_len);in_downsampled=False + if x_at_loop_start is not None:x_at_loop_start=self._hourglass_upsample(x_at_loop_start,full_len) + if v_residual is not 
None:v_residual=v_residual.repeat_interleave(self.hourglass_factor,dim=1)[:,:full_len,:,:]
 + sg=torch.sigmoid(self.hourglass_skip_gate.to(dtype=x.dtype))[None,None,:];x=(1-sg)*x+sg*hg_skip
 + if self.loop_gate is not None and self.looping_active and i==self.loop_start_idx:
 + if not seen_loop_start:seen_loop_start=True;x_at_loop_start=x
 + else:g=torch.sigmoid(self.loop_gate.to(dtype=x.dtype))[None,None,:];x=(1-g)*x+g*x_at_loop_start
 + if skip_idx<self.num_skip_weights and skips:
 + scaled_skip=skips.pop()
 + if hg and scaled_skip.size(1)!=x.size(1):
 + if scaled_skip.size(1)>x.size(1):scaled_skip=self._hourglass_downsample(scaled_skip)
 + else:scaled_skip=self._hourglass_upsample(scaled_skip,x.size(1))
 + scaled_skip=self.skip_weights[skip_idx].to(dtype=x.dtype)[None,None,:]*scaled_skip
 + if self.skip_gates is not None:g=torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None,None,:];x=torch.lerp(scaled_skip,x,g)
 + else:x=x+scaled_skip
 + x,_=self.blocks[i](x,x0,v_residual=v_residual)
 + if hg and in_downsampled:x=self._hourglass_upsample(x,full_len)
 + x=self.final_norm(x)
 + if self.head_proj is not None:x=self.head_proj(x)
 + return x
 + def _forward_pre_softcap(self,input_ids):
 + x=self._forward_hidden(input_ids)
 + if self.tie_embeddings:return F.linear(x,self.tok_emb.weight)
 + else:return self.lm_head(x)
 + def _mos_loss(self,x,target_ids,reduction='mean'):
 + """Mixture of Softmax loss. 
x is pre-projection hidden states [*, dim].""" + tok_w=self.tok_emb.weight if self.tie_embeddings else self.lm_head.weight + sc=self.logit_softcap;K=self.mos_k + # mixing weights: [*, K] + pi=F.softmax(x@self.mos_gate.to(dtype=x.dtype),dim=-1) + # compute K softmax distributions and mix + p_mixed=torch.zeros(*x.shape[:-1],tok_w.shape[0],device=x.device,dtype=torch.float32) + for k in range(K): + x_k=x*self.mos_scales[k].to(dtype=x.dtype)[None,None,:]if x.ndim==3 else x*self.mos_scales[k].to(dtype=x.dtype) + logits_k=F.linear(x_k,tok_w) + logits_k=sc*torch.tanh(logits_k/sc) + p_k=F.softmax(logits_k.float(),dim=-1) + p_mixed=p_mixed+pi[...,k:k+1].float()*p_k + # loss = -log(p_mixed[target]) + p_mixed=p_mixed.clamp(min=1e-8) + flat_p=p_mixed.reshape(-1,p_mixed.size(-1)) + flat_t=target_ids.reshape(-1) + losses=-torch.log(flat_p.gather(1,flat_t.unsqueeze(1)).squeeze(1)) + if reduction=='mean':return losses.mean() + return losses + def _aux_ce_loss(self,target_ids): + """Auxiliary CE from encoder output through tied embeddings.""" + h=self._aux_hidden;h=F.rms_norm(h,(h.size(-1),)) + if self.head_proj is not None:h=self.head_proj(h) + tok_w=self.tok_emb.weight if self.tie_embeddings else self.lm_head.weight + logits=F.linear(h,tok_w);sc=self.logit_softcap + return softcapped_cross_entropy(logits.reshape(-1,logits.size(-1)),target_ids.reshape(-1),sc,reduction='mean')if self.fused_ce_enabled else F.cross_entropy((sc*torch.tanh(logits/sc)).float().reshape(-1,logits.size(-1)),target_ids.reshape(-1),reduction='mean') + def forward_logits(self,input_ids): + if self.mos_k>1: + x=self._forward_hidden(input_ids);tok_w=self.tok_emb.weight if self.tie_embeddings else self.lm_head.weight;sc=self.logit_softcap + pi=F.softmax(x@self.mos_gate.to(dtype=x.dtype),dim=-1) + p_mixed=torch.zeros(*x.shape[:-1],tok_w.shape[0],device=x.device,dtype=torch.float32) + for k in range(self.mos_k): + x_k=x*self.mos_scales[k].to(dtype=x.dtype)[None,None,:] + 
logits_k=sc*torch.tanh(F.linear(x_k,tok_w)/sc) + p_mixed=p_mixed+pi[...,k:k+1].float()*F.softmax(logits_k.float(),dim=-1) + return torch.log(p_mixed.clamp(min=1e-8)) + logits_proj=self._forward_pre_softcap(input_ids);return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward(self,input_ids,target_ids,byte_weights=None,teacher_topk_idx=None,teacher_topk_val=None,kd_alpha=0.,kd_temperature=2.,aux_lr_scale=1.): + if self.mos_k>1: + x=self._forward_hidden(input_ids) + if byte_weights is not None: + losses=self._mos_loss(x,target_ids,reduction='none') + bw=byte_weights.reshape(-1).float();return(losses*bw).sum()/bw.sum() + return self._mos_loss(x,target_ids,reduction='mean') + # Knowledge distillation path + if teacher_topk_idx is not None and kd_alpha>0.: + logits_proj=self._forward_pre_softcap(input_ids);flat_logits=logits_proj.reshape(-1,logits_proj.size(-1));flat_targets=target_ids.reshape(-1) + if self.fused_ce_enabled:ce_loss=softcapped_cross_entropy(flat_logits,flat_targets,self.logit_softcap,reduction='mean') + else:logits_sc=self.logit_softcap*torch.tanh(flat_logits/self.logit_softcap);ce_loss=F.cross_entropy(logits_sc.float(),flat_targets,reduction='mean') + flat_t_idx=teacher_topk_idx.reshape(-1,teacher_topk_idx.size(-1));flat_t_val=teacher_topk_val.reshape(-1,teacher_topk_val.size(-1)) + kd_loss=sparse_kd_loss(flat_logits,flat_t_idx,flat_t_val,kd_temperature,softcap=self.logit_softcap) + return(1.-kd_alpha)*ce_loss+kd_alpha*kd_loss + # Compute per-token losses when hard_token_mining or byte_weights active + htm=Hyperparameters.hard_token_mining + if byte_weights is not None or htm: + if self.fused_ce_enabled: + logits_proj=self._forward_pre_softcap(input_ids) + losses=softcapped_cross_entropy(logits_proj.reshape(-1,logits_proj.size(-1)),target_ids.reshape(-1),self.logit_softcap,reduction='none') + else: + 
logits=self.forward_logits(input_ids);losses=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),target_ids.reshape(-1),reduction='none') + w=torch.ones_like(losses) + if byte_weights is not None:w=byte_weights.reshape(-1).float() + if htm: + k=max(1,int(losses.numel()*Hyperparameters.hard_token_frac)) + _,hard_idx=losses.detach().topk(k) + boost=torch.ones_like(losses);boost[hard_idx]=Hyperparameters.hard_token_boost + w=w*boost + loss=(losses*w).sum()/w.sum() + elif self.fused_ce_enabled: + logits_proj=self._forward_pre_softcap(input_ids) + loss=softcapped_cross_entropy(logits_proj.reshape(-1,logits_proj.size(-1)),target_ids.reshape(-1),self.logit_softcap,reduction='mean') + else:logits=self.forward_logits(input_ids);loss=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),target_ids.reshape(-1),reduction='mean') + if self.aux_loss_weight>0 and self._aux_hidden is not None:loss=loss+self.aux_loss_weight*aux_lr_scale*self._aux_ce_loss(target_ids) + return loss + def forward_ttt(self, input_ids, target_ids, lora): + """Forward pass with batched LoRA adapters for TTT. 
Returns per-token loss [bsz, seq_len].""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips = [] + v_residual = None + enc_iter = self.encoder_indices if self.looping_active else range(self.num_encoder_layers) + dec_iter = self.decoder_indices if self.looping_active else range(self.num_encoder_layers, self.num_encoder_layers + self.num_decoder_layers) + slot = 0 + is_first_layer = True + seen_loop_start = False + x_at_loop_start = None + hg = self.hourglass_enabled + full_len = x.size(1) + hg_skip = None + in_downsampled = False + for step_idx, i in enumerate(enc_iter): + if self.loop_gate is not None and self.looping_active and i == self.loop_start_idx: + if not seen_loop_start: + seen_loop_start = True + x_at_loop_start = x + else: + g = torch.sigmoid(self.loop_gate.to(dtype=x.dtype))[None, None, :] + x = (1 - g) * x + g * x_at_loop_start + x, v_out = self._block_with_lora(self.blocks[i], x, x0, lora, slot, v_residual=v_residual) + if _V_RESIDUAL_ENABLED and is_first_layer: + v_residual = v_out + is_first_layer = False + slot += 1 + skips.append(x) + if hg and not in_downsampled and step_idx == self.hourglass_down_after: + hg_skip = x + x = self._hourglass_downsample(x) + x0 = self._hourglass_downsample(x0) + in_downsampled = True + if x_at_loop_start is not None: + x_at_loop_start = self._hourglass_downsample(x_at_loop_start) + if v_residual is not None: + v_residual = v_residual[:, ::self.hourglass_factor, :, :] + for skip_idx, i in enumerate(dec_iter): + if hg and in_downsampled and skip_idx == self.hourglass_up_before: + x = self._hourglass_upsample(x, full_len) + x0 = self._hourglass_upsample(x0, full_len) + in_downsampled = False + if x_at_loop_start is not None: + x_at_loop_start = self._hourglass_upsample(x_at_loop_start, full_len) + if v_residual is not None: + v_residual = v_residual.repeat_interleave(self.hourglass_factor, dim=1)[:, :full_len, :, :] + sg = 
torch.sigmoid(self.hourglass_skip_gate.to(dtype=x.dtype))[None, None, :] + x = (1 - sg) * x + sg * hg_skip + if self.loop_gate is not None and self.looping_active and i == self.loop_start_idx: + if not seen_loop_start: + seen_loop_start = True + x_at_loop_start = x + else: + g = torch.sigmoid(self.loop_gate.to(dtype=x.dtype))[None, None, :] + x = (1 - g) * x + g * x_at_loop_start + if skip_idx < self.num_skip_weights and skips: + scaled_skip = skips.pop() + if hg and scaled_skip.size(1) != x.size(1): + if scaled_skip.size(1) > x.size(1): + scaled_skip = self._hourglass_downsample(scaled_skip) + else: + scaled_skip = self._hourglass_upsample(scaled_skip, x.size(1)) + scaled_skip = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * scaled_skip + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x, _ = self._block_with_lora(self.blocks[i], x, x0, lora, slot, v_residual=v_residual) + slot += 1 + if hg and in_downsampled: + x = self._hourglass_upsample(x, full_len) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.mos_k > 1: + tok_w = self.tok_emb.weight if self.tie_embeddings else self.lm_head.weight + sc = self.logit_softcap + pi = F.softmax(x @ self.mos_gate.to(dtype=x.dtype), dim=-1) + bsz, sl, dim = x.shape + V = tok_w.shape[0] + p_mixed = torch.zeros(bsz, sl, V, device=x.device, dtype=torch.float32) + for k in range(self.mos_k): + x_k = x * self.mos_scales[k].to(dtype=x.dtype)[None, None, :] + logits_k = F.linear(x_k, tok_w) + lora.lm_head_lora(x_k) + logits_k = sc * torch.tanh(logits_k / sc) + p_k = F.softmax(logits_k.float(), dim=-1) + p_mixed = p_mixed + pi[..., k:k+1].float() * p_k + p_mixed = p_mixed.clamp(min=1e-8) + flat_p = p_mixed.reshape(-1, V) + flat_t = target_ids.reshape(-1) + losses = -torch.log(flat_p.gather(1, flat_t.unsqueeze(1)).squeeze(1)) + return 
losses.reshape(bsz, sl) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + def _block_with_lora(self, block, x, x0, lora, slot, v_residual=None): + """Single block forward with LoRA injection, handles both parallel and sequential.""" + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Q with LoRA + q = (attn.c_q(n) + lora.q_loras[slot](n)).reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + # K with optional LoRA + k = attn.c_k(n) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + # V with LoRA + v = (attn.c_v(n) + lora.v_loras[slot](n)).reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + if _V_RESIDUAL_ENABLED and v_residual is not None: + v_mix = torch.sigmoid(attn.v_mix).to(dtype=v.dtype)[None, None, :, None] + v = v_mix * v + (1.0 - v_mix) * v_residual + v_out = v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + rep = attn.num_heads // attn.num_kv_heads + if rep > 1: + k = k[:,:,:,None,:].expand(bsz,seqlen,attn.num_kv_heads,rep,attn.head_dim).reshape(bsz,seqlen,attn.num_heads,attn.head_dim) + v = 
v[:,:,:,None,:].expand(bsz,seqlen,attn.num_kv_heads,rep,attn.head_dim).reshape(bsz,seqlen,attn.num_heads,attn.head_dim) + y = F.scaled_dot_product_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), is_causal=True).transpose(1,2) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + attn_out = attn.proj(y) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](y) + if block.parallel: + mlp_n = block.mlp_norm(x_in) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_out) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + block.mlp_scale.to(dtype=x_in.dtype)[None, None, :] * mlp_out + else: + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_out) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out, v_out +def classify_param(name): + if'tok_emb'in name or'lm_head'in name:return'embed' + if'.mlp.'in name:return'mlp' + if'.attn.'in name or'.proj.'in name and'.mlp.'not in name:return'attn' + return'other' +@torch.compile +def zeropower_via_newtonschulz5(G,steps=5,eps=1e-07): + coeffs=[(8.156554524902461,-22.48329292557795,15.878769915207462),(4.042929935166739,-2.808917465908714,0.5000178451051316),(3.8916678022926607,-2.772484153217685,0.5060648178503393),(3.285753657755655,-2.3681294933425376,0.46449024233003106),(2.3465413258596377,-1.7097828382687081,0.42323551169305323)] + X=G.bfloat16();X/=X.norm()+eps;transposed=G.size(0)>G.size(1) + if transposed:X=X.T + for a,b,c in coeffs[:steps]:A=X@X.T;B=b*A+c*A@A;X=a*X+B@X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def 
__init__(self,params,lr,momentum,backend_steps,nesterov=True,weight_decay=.0,row_normalize=False):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay,row_normalize=row_normalize)) + @torch.no_grad() + def step(self,closure=None): + loss=None + if closure is not None: + with torch.enable_grad():loss=closure() + distributed=dist.is_available()and dist.is_initialized();world_size=dist.get_world_size()if distributed else 1;rank=dist.get_rank()if distributed else 0;solo=getattr(self,'_multi_traj_active',False) + for group in self.param_groups: + params=group['params'] + if not params:continue + lr=group['lr'];momentum=group['momentum'];backend_steps=group['backend_steps'];nesterov=group['nesterov'];total_params=sum(int(p.numel())for p in params);updates_flat=torch.zeros(total_params,device=params[0].device,dtype=torch.bfloat16);curr=0 + for(i,p)in enumerate(params): + if (solo or i%world_size==rank) and p.grad is not None: + g=p.grad;state=self.state[p] + if'momentum_buffer'not in state:state['momentum_buffer']=torch.zeros_like(g) + buf=state['momentum_buffer'];buf.mul_(momentum).add_(g) + if nesterov:g=g.add(buf,alpha=momentum) + if group.get('row_normalize',False):row_norms=g.float().norm(dim=-1,keepdim=True).clamp_min(1e-07);g=g/row_norms.to(g.dtype) + g=zeropower_via_newtonschulz5(g,steps=backend_steps);g*=max(1,g.size(0)/g.size(1))**.5;updates_flat[curr:curr+p.numel()]=g.reshape(-1) + curr+=p.numel() + if distributed and not solo:dist.all_reduce(updates_flat,op=dist.ReduceOp.SUM) + wd=group.get('weight_decay',.0);curr=0 + for p in params: + if wd>.0:p.data.mul_(1.-lr*wd) + g=updates_flat[curr:curr+p.numel()].view_as(p).to(dtype=p.dtype);p.add_(g,alpha=-lr);curr+=p.numel() + return loss +CONTROL_TENSOR_NAME_PATTERNS=tuple(pattern for pattern in 
os.environ.get('CONTROL_TENSOR_NAME_PATTERNS','attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,v_mix,loop_gate').split(',')if pattern) +class Optimizers: + def __init__(self,h,base_model): + block_named_params=list(base_model.blocks.named_parameters());matrix_params=[p for(name,p)in block_named_params if p.ndim==2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)];scalar_params=[p for(name,p)in block_named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel()>0:scalar_params.append(base_model.skip_gates) + if base_model.loop_gate is not None:scalar_params.append(base_model.loop_gate) + if getattr(base_model,'mos_k',1)>1: + for s in base_model.mos_scales:scalar_params.append(s) + scalar_params.append(base_model.mos_gate) + token_lr=h.tied_embed_lr if h.tie_embeddings else h.embed_lr;tok_params=[{'params':[base_model.tok_emb.weight],'lr':token_lr,'base_lr':token_lr}];self.optimizer_tok=torch.optim.AdamW(tok_params,betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.embed_wd,fused=True);self.optimizer_muon=Muon(matrix_params,lr=h.matrix_lr,momentum=h.muon_momentum,backend_steps=h.muon_backend_steps,weight_decay=h.muon_wd,row_normalize=h.muon_row_normalize) + for group in self.optimizer_muon.param_groups:group['base_lr']=h.matrix_lr + self.optimizer_scalar=torch.optim.AdamW([{'params':scalar_params,'lr':h.scalar_lr,'base_lr':h.scalar_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.adam_wd,fused=True);self.optimizers=[self.optimizer_tok,self.optimizer_muon,self.optimizer_scalar] + if base_model.lm_head is not 
None:self.optimizer_head=torch.optim.Adam([{'params':[base_model.lm_head.weight],'lr':h.head_lr,'base_lr':h.head_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,fused=True);self.optimizers.insert(1,self.optimizer_head) + else:self.optimizer_head=None + def __iter__(self):return iter(self.optimizers) + def zero_grad_all(self): + for opt in self.optimizers:opt.zero_grad(set_to_none=True) + def step(self): + for opt in self.optimizers:opt.step() + self.zero_grad_all() +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module,CastedLinear):module.float() + for(name,param)in model.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +def collect_hessians(model,train_loader,h,device,n_calibration_batches=64): + hessians={};hooks=[] + def make_hook(name): + def hook_fn(module,inp,out): + x=inp[0].detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + for(name,module)in model.named_modules(): + if isinstance(module,CastedLinear)and module.weight.numel()>65536: + cat=classify_param(name+'.weight') + if cat in('mlp','attn'):hooks.append(module.register_forward_hook(make_hook(name+'.weight'))) + if model.tie_embeddings: + hook_module=model.head_proj if model.head_proj is not None else model.final_norm + def make_output_hook(name): + def hook_fn(module,inp,out): + x=out.detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook('tok_emb.weight'))) + model.eval() + with torch.no_grad(): + for _ in 
range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x) + for hook in hooks:hook.remove() + for name in hessians:hessians[name]=hessians[name].cpu()/n_calibration_batches + return hessians +def _gptq_core(W_orig,H_prepared,Hinv,perm,invperm,s,clip_range,block_size): + rows,cols=W_orig.shape;sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_orig[:,perm].clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i2=1.:s_cand=(W_orig.abs().max(dim=1).values/clip_range).clamp_min(1e-10).to(torch.float16) + else:s_cand=(torch.quantile(W_orig.abs(),pct,dim=1)/clip_range).clamp_min(1e-10).to(torch.float16) + sf_c=s_cand.float();Q_c=torch.zeros(rows,cols,dtype=torch.int8);W_work_c=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work_c[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf_c),-clip_range,clip_range);Q_c[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf_c)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20.: + diagH=torch.diag(H).clamp_min(1e-8);col_imp=diagH/diagH.mean();row_imp=(W_orig.abs()*col_imp.unsqueeze(0)).mean(dim=1);row_imp=row_imp/row_imp.mean();adj=1.+hessian_clip_lambda*(row_imp-1.);s=(clip_sigmas*row_std*adj/clip_range).clamp_min(1e-10).to(torch.float16) + else:s=(clip_sigmas*row_std/clip_range).clamp_min(1e-10).to(torch.float16) + 
sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20:cs=h.mlp_clip_sigmas + elif'.attn.'in name and h.attn_clip_sigmas>0:cs=h.attn_clip_sigmas + else:cs=h.matrix_clip_sigmas + bits=h.embed_bits if'tok_emb'in name else h.matrix_bits;hcl=0. if'tok_emb'in name else h.hessian_clip_lambda;q,s=gptq_quantize_weight(t,hessians[name],clip_sigmas=cs,clip_range=2**(bits-1)-1,hessian_clip_lambda=hcl);result[name+'.q']=q;result[name+'.scale']=s;meta[name]=f"gptq (int{bits})" + if lqer_on and 'tok_emb' not in name: + W_q=q.float()*s.float().view(-1,1);E=t.float()-W_q;lqer_cands[name]=(E,float(E.norm())) + if lqer_on and lqer_cands: + top=sorted(lqer_cands.items(),key=lambda kv:-kv[1][1])[:h.lqer_top_k];asym_g=h.lqer_asym_group + for(name,(E,enorm))in top: + U,S,Vh=torch.linalg.svd(E,full_matrices=False);r=min(h.lqer_rank,S.numel()) + A=(U[:,:r]*S[:r]).contiguous();B=Vh[:r,:].contiguous() + if B.numel()%asym_g==0: + qA,sA,qB,sB=_lqer_pack_asym(A,B,asym_g) + result[name+'.lqA_a']=qA;result[name+'.lqAs_a']=sA;result[name+'.lqB_a']=qB;result[name+'.lqBs_a']=sB + meta[name]=meta[name]+'+lqer_asym';log(f" LQER asym rank-{r}: {name} (err_norm={enorm:.4f})") + else: + log(f" LQER skip {name}: B.numel()={B.numel()} not divisible by group={asym_g}") + categories=collections.defaultdict(set) + for(name,cat)in meta.items():short=re.sub('\\.\\d+$','',re.sub('blocks\\.\\d+','blocks',name));categories[cat].add(short) + log('Quantized weights:') + for cat in sorted(categories):log(f" {cat}: {", ".join(sorted(categories[cat]))}") + return result,meta 
+def dequantize_mixed(result,meta,template_sd): + out={} + for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is None:continue + orig_dtype=orig.dtype + if'passthrough'in info: + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + q,s=result[name+'.q'],result[name+'.scale'] + if s.ndim>0:W=q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1)) + else:W=q.float()*float(s.item()) + if'lqer_asym'in info: + qA_t=result[name+'.lqA_a'];sA_t=result[name+'.lqAs_a'];qB_t=result[name+'.lqB_a'];sB_t=result[name+'.lqBs_a'] + qA=qA_t.float()*float(sA_t);g_sz=qB_t.numel()//sB_t.numel() + qB=(qB_t.reshape(-1,g_sz).float()*sB_t.float().view(-1,1)).reshape(qB_t.shape) + W=W+qA@qB + out[name]=W.to(orig_dtype) + return out +def scale_tune_post_gptq(quant_result, quant_meta, template_sd, h, device): + """Optimize per-row quantization scales using actual CE loss on calibration data. + + After GPTQ produces initial scales, this function makes the scales learnable + and fine-tunes them via backprop through the dequantized model. The integer + quantized weights (Q_int) stay frozen; only the float16 scale tensors are updated. + This minimizes the actual CE loss instead of the per-layer MSE heuristic that GPTQ uses. 
+ """ + log(f"Scale tuning: {h.scale_tuning_steps} steps, lr={h.scale_tuning_lr}, batches={h.scale_tuning_batches}") + t0 = time.perf_counter() + + # Build a shell model for functional_call (weights will be overridden) + tune_model = GPT(h).to(device).bfloat16() + restore_fp32_params(tune_model) + tune_model.eval() + + # Collect the frozen Q_int tensors and learnable scale params + scale_params = [] + q_int_map = {} # name -> frozen Q_int tensor on device + + for name, info in quant_meta.items(): + if 'gptq' not in info: + continue + q_key = name + '.q' + s_key = name + '.scale' + q_int_map[name] = quant_result[q_key].to(device) # frozen int8 + # Make scale a learnable parameter (float32 for optimizer stability) + scale_params.append((name, quant_result[s_key].float().to(device).requires_grad_(True))) + + # Build optimizer over scale params only + optim_params = [sp for _, sp in scale_params] + optimizer = torch.optim.Adam(optim_params, lr=h.scale_tuning_lr) + + # Load calibration data + calib_loader = ShuffledSequenceLoader(h, device) + + # Pre-collect calibration batches (small number, reuse across steps) + calib_data = [] + for _ in range(h.scale_tuning_batches): + x, y = calib_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + calib_data.append((x, y)) + + best_loss = float('inf') + best_scales = {name: sp.detach().clone() for name, sp in scale_params} + + for step in range(h.scale_tuning_steps): + optimizer.zero_grad() + total_loss = 0.0 + + # Build dequantized state dict (differentiable through scales) + deq_sd = {} + for pname, orig in template_sd.items(): + info = quant_meta.get(pname) + if info is None: + continue + if 'passthrough' in info: + t = quant_result[pname] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + deq_sd[pname] = t.to(device) + continue + q_int = q_int_map[pname] + scale_tensor = None + for sname, sp in scale_params: + if sname == pname: + scale_tensor = sp + break + if 
scale_tensor is not None: + deq_sd[pname] = (q_int.float() * scale_tensor.view(q_int.shape[0], 1)).to(orig.dtype) + else: + deq_sd[pname] = (q_int.float() * quant_result[pname + '.scale'].float().to(device).view(q_int.shape[0], 1)).to(orig.dtype) + + # Forward pass via functional_call (keeps gradients flowing through scales) + for batch_idx in range(len(calib_data)): + inp, tgt = calib_data[batch_idx] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True): + loss = torch.func.functional_call(tune_model, deq_sd, (inp, tgt)) + total_loss += loss.item() + (loss / len(calib_data)).backward() + + avg_loss = total_loss / len(calib_data) + optimizer.step() + + # Clamp scales to stay positive + with torch.no_grad(): + for _, sp in scale_params: + sp.clamp_min_(1e-10) + + if avg_loss < best_loss: + best_loss = avg_loss + best_scales = {name: sp.detach().clone() for name, sp in scale_params} + + if step % 5 == 0 or step == h.scale_tuning_steps - 1: + log(f" scale_tune step {step}: CE loss = {avg_loss:.6f}") + + # Write optimized scales back into quant_result + for name, sp in best_scales.items(): + quant_result[name + '.scale'] = sp.cpu().to(torch.float16) + + elapsed = time.perf_counter() - t0 + log(f"Scale tuning done in {elapsed:.1f}s, best CE = {best_loss:.6f}") + return quant_result + +_BSHF_MAGIC=b'BSHF' +def _byte_shuffle(data,stride=2): + if stride<=1 or len(data)target_bytes: + over=len(quant_blob)-target_bytes;log(f"prune:over by {over} bytes, selective pruning") + candidates=[] + for name,info in quant_meta.items(): + if'gptq'not in info:continue + q=quant_result[name+'.q'];s=quant_result[name+'.scale'] + sf=s.float().view(-1) if s.ndim==1 else s.float()[:,0] + for r in range(q.shape[0]): + for c in range(q.shape[1]): + v=int(q[r,c]) + if 01:nll=F.nll_loss(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + 
else:nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else context_size;scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=val_data.base_bytes_lut[tgt].to(torch.float64);tb+=(val_data.has_leading_space_lut[tgt]&~val_data.is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + base_model.train();return _loss_bpb(loss_sum,token_count,byte_count) +# ===== Batched LoRA TTT (ported from 1.06335 submission) ===== +class BatchedLinearLoRA(nn.Module): + """LoRA adapter with batched parameters [bsz, rank, dim] for parallel doc processing.""" + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter(torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound)) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + +class BatchedTTTLoRA(nn.Module): + """Container for all LoRA adapters needed for TTT on our model.""" + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.blocks[0].attn.c_q.in_features + vocab = model.tok_emb.num_embeddings + embed_dim = 
model.tok_emb.embedding_dim + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * model.blocks[0].attn.head_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)]) + self.v_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)]) + self.k_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)]) if k_lora else None + self.mlp_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)]) if mlp_lora else None + self.o_loras = nn.ModuleList([BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)]) if o_lora else None + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + +BOS_ID = None + +def _find_docs(all_tokens): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted(random.Random(h.seed).sample(range(len(docs)), sample_n)) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = 
h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [global_doc_entries[i : i + batch_size] for i in range(0, len(global_doc_entries), batch_size)] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb(ptl, x, y, chunk_offsets, chunk_lens, pos_idx, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + loss_sum, byte_sum, token_count): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ((chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1))) + mask_f64 = mask.to(torch.float64) + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to(torch.float64) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) 
+ return val_loss, val_bpb + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + """Global SGD TTT: train base model weights on scored prefix docs.""" + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD(ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = h.global_ttt_warmup_start_lr + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * (1.0 + math.cos(math.pi * decay_ci / decay_steps)) + for pg 
in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log(f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s") + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train): + """Phased LoRA TTT: document-level batched scoring with per-doc LoRA adaptation.""" + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + 
phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log(f"ttt_phased: total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit} " + f"num_phases:{num_phases} boundaries:{phase_boundaries}") + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches(doc_entries, h, ascending=use_ascending) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD(lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay) + return torch.optim.AdamW(lora.parameters(), 
lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True) + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * 
chunk_size, ci + 1, chunk_size, eval_seq_len) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to(device=device, dtype=torch.int64, non_blocking=True) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to(device, non_blocking=True) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + with torch.no_grad(): + _accumulate_bpb(per_tok_loss, x, y, chunk_offsets, chunk_lens, ctx_pos, + val_data.base_bytes_lut, val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, byte_sum, token_count) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + for gi in range(h.ttt_grad_steps): + if gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[:, chunk_offset : chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = 
cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log(f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}") + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch)) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + log(f"ttpp: phase:{current_phase + 1}/{num_phases} " + f"gd:{len(scored_docs_for_global)} t:{time.perf_counter() - t_start:.1f}s") + train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, global_ttt_tokens) + for p in 
base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + +def timed_eval(label,fn,*args,**kwargs):torch.cuda.synchronize();t0=time.perf_counter();val_loss,val_bpb=fn(*args,**kwargs);torch.cuda.synchronize();elapsed_ms=1e3*(time.perf_counter()-t0);log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms");return val_loss,val_bpb +def train_model(h,device,val_data): + compile_mode=os.environ.get('COMPILE_MODE','default');base_model=GPT(h).to(device).bfloat16();restore_fp32_params(base_model) + # Build byte-count LUT for byte-weighted CE loss + _byte_lut=None + if h.byte_weighted_ce: + sp=spm.SentencePieceProcessor(model_file=h.tokenizer_path);base_bytes_lut,_,_=build_sentencepiece_luts(sp,h.vocab_size,device);_byte_lut=base_bytes_lut.clamp(min=1,max=4).float();log(f"byte_weighted_ce:enabled, lut shape={_byte_lut.shape}, mean_weight={_byte_lut.mean().item():.3f}") + compiled_model=torch.compile(base_model,mode=compile_mode if 
compile_mode!='default'else None,dynamic=False,fullgraph=True) + if h.distributed:model=DDP(compiled_model,device_ids=[h.local_rank],broadcast_buffers=False) + else:model=compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}");optimizers=Optimizers(h,base_model);train_loader=ShuffledSequenceLoader(h,device);max_wallclock_ms=1e3*h.max_wallclock_seconds if h.max_wallclock_seconds>0 else None + if max_wallclock_ms is not None:max_wallclock_ms-=h.gptq_reserve_seconds*1e3;log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + def training_frac(step,elapsed_ms): + if max_wallclock_ms is None:return step/max(h.iterations,1) + return elapsed_ms/max(max_wallclock_ms,1e-09) + lr_schedule=os.environ.get('LR_SCHEDULE','linear') + def lr_mul(frac): + if h.warmdown_frac<=0:return 1. + if frac>=1.-h.warmdown_frac: + raw=(1.-frac)/h.warmdown_frac + if lr_schedule=='sqrt':return max(math.sqrt(raw),h.min_lr) + return max(raw,h.min_lr) + return 1. 
+ _cur_batch_tokens=[h.train_batch_tokens];_kd_active=[h.kd_enabled];_approx_training_ms=[0.];_multi_traj=[False] + if h.kd_enabled:log(f"kd:enabled alpha={h.kd_alpha} temp={h.kd_temperature} top_k={h.kd_top_k} dir={h.kd_logits_dir} warmup_frac={h.kd_warmup_frac}") + def step_fn(step,lr_scale): + optimizers.zero_grad_all();train_loss=torch.zeros((),device=device) + batch_tokens=_cur_batch_tokens[0] + for micro_step in range(h.grad_accum_steps): + if h.distributed:model.require_backward_grad_sync=(not _multi_traj[0]) and micro_step==h.grad_accum_steps-1 + batch_data=train_loader.next_batch(batch_tokens,h.grad_accum_steps) + x,y=batch_data[0],batch_data[1] + bw=_byte_lut[y] if _byte_lut is not None else None + if _kd_active[0] and len(batch_data)==4: + t_idx,t_val=batch_data[2],batch_data[3] + kd_alpha_eff=h.kd_alpha + if h.kd_warmup_frac>0: + approx_ms=_approx_training_ms[0] + frac=training_frac(step,approx_ms) + if frac0 else 1.;muon_momentum=(1-frac)*h.muon_momentum_warmup_start+frac*h.muon_momentum + for group in optimizers.optimizer_muon.param_groups:group['momentum']=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group['lr']=group['base_lr']*lr_scale + if h.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),h.grad_clip_norm) + optimizers.step();return train_loss + if h.warmup_steps>0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,1.) + if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + + if len(base_model.encoder_indices)!=base_model.num_encoder_layers: + base_model.looping_active=True;log(f"loop_warmup:on") + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,1.) 
+ if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active=False + base_model.load_state_dict(initial_model_state,strict=True) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=True):opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed:model.require_backward_grad_sync=True + train_loader=ShuffledSequenceLoader(h,device) + ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=h.ema_decay;swa_state=None;swa_count=0;training_time_ms=.0;stop_after_step=None;torch.cuda.synchronize();t0=time.perf_counter();step=0 + while True: + last_step=step==h.iterations or stop_after_step is not None and step>=stop_after_step;should_validate=last_step or h.val_loss_every>0 and step%h.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(h,device,val_data,model);log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not None and step=h.enable_looping_at:base_model.looping_active=True;log(f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + if h.sparsity_start_frac>0 and frac>=h.sparsity_start_frac and not getattr(base_model,'_sparsity_applied',False): + with torch.no_grad(): + for module in base_model.modules(): + if isinstance(module,MLP): + for linear in[module.fc,module.proj]: + w=linear.weight.data + for i in range(0,w.shape[1],4): + group=w[:,i:i+4].abs();_,idx=group.topk(2,dim=1,largest=False) + for j in range(2):w[range(w.shape[0]),i+idx[:,j]]=0 + base_model._sparsity_applied=True;log(f"sparsity:applied 2:4 pruning at frac={frac:.3f}") + momentum_cooldown=float(os.environ.get('MOMENTUM_COOLDOWN',0.)) + if momentum_cooldown>0 and 
scale<1.: + cool_mom=h.muon_momentum-momentum_cooldown*(1.-scale); + for group in optimizers.optimizer_muon.param_groups:group['momentum']=cool_mom + momentum_warmdown_target=float(os.environ.get('MOMENTUM_WARMDOWN_TARGET','0')) + if momentum_warmdown_target>0 and scale<1.: + warmdown_progress=1.-scale + mom=h.muon_momentum+(momentum_warmdown_target-h.muon_momentum)*warmdown_progress + for group in optimizers.optimizer_muon.param_groups:group['momentum']=mom + if h.multi_traj_swa and h.distributed and scale<1. and not _multi_traj[0]:_multi_traj[0]=True;optimizers.optimizer_muon._multi_traj_active=True;log(f"multi_traj_swa:enabled step:{step} frac:{frac:.3f}") + train_loss=step_fn(step,scale) + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=1.-ema_decay) + swa_decay=float(os.environ.get('SWA_DECAY',0)) + if h.swa_enabled and scale0: + for(name,t)in base_model.state_dict().items():swa_state[name].mul_(swa_decay).add_(t.detach().float(),alpha=1.-swa_decay) + swa_count=1 + else: + for(name,t)in base_model.state_dict().items():swa_state[name].add_(t.detach().float()) + swa_count+=1 + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0);should_log_train=h.train_log_every>0 and(step<=5 or step%h.train_log_every==0 or stop_after_step is not None) + if should_log_train: + tok_per_sec=step*h.train_batch_tokens/(approx_training_time_ms/1e3);log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}") + if step%h.train_log_every==0: + gnorms=[]; + for i,blk in enumerate(base_model.blocks):gnorms.append(f"L{i}:{sum(p.grad.norm().item() for p in blk.parameters() if p.grad is not None):.3f}") + log(f"grad_norms: {' '.join(gnorms)}") + reached_cap=max_wallclock_ms is not None and approx_training_time_ms>=max_wallclock_ms + if h.distributed and max_wallclock_ms is not 
None:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap:stop_after_step=step + log(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");current_state=base_model.state_dict() + if h.swa_enabled and swa_state is not None and swa_count>0: + if _multi_traj[0] and h.distributed: + for name in swa_state:dist.all_reduce(swa_state[name],op=dist.ReduceOp.SUM);swa_state[name]/=h.world_size + log(f"multi_traj_swa:averaged {h.world_size} trajectories") + log(f"swa:applying SWA weights ({swa_count} checkpoints)");avg_state={name:(t/swa_count).to(dtype=current_state[name].dtype)for(name,t)in swa_state.items()} + else: + log('ema:applying EMA weights');avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()} + base_model.load_state_dict(avg_state,strict=True);return base_model,compiled_model +def train_and_eval(h,device): + eval_only=bool(int(os.environ.get('EVAL_ONLY','0'))) + random.seed(h.seed);np.random.seed(h.seed);torch.manual_seed(h.seed);torch.cuda.manual_seed_all(h.seed);val_data=ValidationData(h,device);log(f"val_tokens: {val_data.val_tokens.numel()-1}") + if eval_only: + log("EVAL_ONLY=1: skipping training and serialization, loading existing quantized model") + else: + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}");base_model,compiled_model=train_model(h,device,val_data);torch._dynamo.reset();timed_eval('pre-quantization post-ema',eval_val,h,device,val_data,compiled_model);serialize(h,base_model,Path(__file__).read_text(encoding='utf-8')) + if h.distributed:dist.barrier() + eval_model=deserialize(h,device) + if len(eval_model.encoder_indices)!=eval_model.num_encoder_layers:eval_model.looping_active=True + 
compiled_model=torch.compile(eval_model,dynamic=False,fullgraph=True);timed_eval('quantized',eval_val,h,device,val_data,compiled_model) + if h.sliding_window_enabled:timed_eval('quantized_sliding_window',eval_val_sliding,h,device,val_data,eval_model) + if h.ttt_enabled and h.ttt_lora_rank > 0: + ttt_model=deserialize(h,device) + if len(ttt_model.encoder_indices)!=ttt_model.num_encoder_layers:ttt_model.looping_active=True + # Warm up rotary caches for TTT eval seq len + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + _fwd_ttt_compiled_inner = None + def _fwd_ttt(input_ids, target_ids, lora): + nonlocal _fwd_ttt_compiled_inner + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + fwd_ttt_compiled = _fwd_ttt + # Compile warmup with random tokens + log(f"ttt_lora:warming up compile") + t_warmup = time.perf_counter() + for bsz_w in [h.ttt_batch_size]: + wl = BatchedTTTLoRA(bsz_w, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora).to(device) + wo = torch.optim.AdamW(wl.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), eps=1e-10, + weight_decay=h.ttt_weight_decay, fused=True) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz_w, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz_w, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + 
wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + log(f"ttt_lora:compile warmup done ({time.perf_counter() - t_warmup:.1f}s)") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log("beginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased(h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log(f"quantized_ttt_phased val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} eval_time:{ttt_eval_elapsed*1e3:.0f}ms") +def main(): + world_size=int(os.environ.get('WORLD_SIZE','1'));local_rank=int(os.environ.get('LOCAL_RANK','0'));distributed='RANK'in os.environ and'WORLD_SIZE'in os.environ + if not torch.cuda.is_available():raise RuntimeError('CUDA is required') + if world_size<=0:raise ValueError(f"bad ws") + if 8%world_size!=0:raise ValueError(f"ws must divide 8") + device=torch.device('cuda',local_rank);torch.cuda.set_device(device) + if distributed:dist.init_process_group(backend='nccl',device_id=device);dist.barrier() + torch.backends.cuda.matmul.allow_tf32=True;torch.backends.cudnn.allow_tf32=True;torch.set_float32_matmul_precision('high');from torch.backends.cuda import enable_cudnn_sdp,enable_flash_sdp,enable_math_sdp,enable_mem_efficient_sdp;enable_cudnn_sdp(False);enable_flash_sdp(True);enable_mem_efficient_sdp(False);enable_math_sdp(False);torch._dynamo.config.optimize_ddp=False + h=Hyperparameters();set_logging_hparams(h) + if h.is_main_process: + os.makedirs('logs',exist_ok=True);log('Hyperparameters:',console=True) + for(k,v)in sorted(vars(type(h)).items()): + if not k.startswith('_'):log(f" {k}: {v}",console=True) + log('='*100,console=False) + train_and_eval(h,device) + if distributed:dist.destroy_process_group() +if __name__=='__main__':main() \ No newline at end of file