diff --git a/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/README.md b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/README.md new file mode 100644 index 0000000000..baee29a0eb --- /dev/null +++ b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/README.md @@ -0,0 +1,127 @@ +# Alpha=144 LoRA + Warm-Start A + WD=1.0 — 1.07209 BPB + +**val_bpb: 1.07208661** (3-seed mean: seeds 1337, 42, 314) + +## Results + +| Seed | BPB | Eval time | Artifact | +|------|-----|-----------|----------| +| 1337 | 1.07189164 | 456.5s | 15,935,101 B | +| 42 | 1.07247808 | 456.7s | 15,930,195 B | +| 314 | 1.07189010 | 455.7s | 15,935,817 B | +| **Mean** | **1.07208661** | | | + +All runs: train ≤600s, eval ≤600s, artifact ≤16MB. + +## Four novel changes on top of dexhunter's phased-TTT pipeline + +Prior phased-TTT submissions (PR #1530 @samacqua, PR #1610 @romeerp, @dexhunter 1.07193) +use `BatchedLinearLoRA` with these defaults: + +- `forward(x) = (x @ A.T) @ B.T` *(no rank scaling)* +- `reset()`: re-randomize A uniform in [-1/√in, +1/√in], zero B +- `TTT_WEIGHT_DECAY = 0.5` +- `TTT_LORA_RANK = 96` + +This submission composes four small changes to the LoRA module: + +### (1) Alpha/rank output scaling — enables safe higher rank + +```python +class BatchedLinearLoRA(nn.Module): + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + + def __init__(self, bsz, in_features, out_features, rank): + ... + self._scale = self._ALPHA / rank # <-- novel + + def forward(self, x): + return ((x @ self.A.T) @ self.B.T) * self._scale # <-- novel +``` + +Without this, raising rank directly causes divergence on some seeds (we saw +seeds 314/1337 collapse to ~1.133 BPB with raw rank 128). + +### (2) Warm-start A across batches + +```python +_WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + +def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() +``` + +Phased TTT processes ~780 batches of ~64 documents each. Previously A was +re-randomized every batch, discarding whatever feature directions the +optimizer found. Keeping A warm (B still zeroes) lets A accumulate useful +directions across the eval while still starting each batch with LoRA output = 0. + +### (3) Raised TTT weight decay 0.5 → 1.0 to stabilize (2) + +Warm-start A alone regresses on seed 314 (A drifts into an over-specialized +state for that seed's doc ordering). Doubling weight decay explicitly +counteracts this drift — on seed 314 it restores parity, on other seeds the +warm-start gain is preserved. + +### (4) Lift alpha from 96 to 144 (effective scale 1.125 on rank 128) + +With (1)+(2)+(3) stable, the LoRA is under-utilized. Alpha=96 gives +`scale = 96/128 = 0.75` — weaker than the prior no-scaling code. Raising +alpha to 144 gives `scale = 144/128 = 1.125`, so the LoRA has more +adaptation strength per step. WD=1.0 keeps it from destabilizing. + +### Ablation on seed 42 + +| Config | TTT BPB | Delta vs baseline | +|--------|---------|-------------------| +| rank 96 baseline | 1.07341 | 0 | +| + alpha 96 scaling, rank 128 | 1.07320 | −0.00021 | +| + warm-start A | 1.07259 | −0.00082 | +| + WD=1.0 | 1.07298 | −0.00043 | +| **+ alpha 144** (this work) | **1.07248** | **−0.00093** | + +### Combined 3-seed result + +| Seed | rank-96 baseline | + alpha 96 rank 128 | + warm A + WD=1.0 | **+ alpha 144** | +|------|------------------|---------------------|--------------------|------------------| +| 1337 | 1.07423 | 1.07379 | 1.07298 | **1.07189** | +| 42 | 1.07341 | 1.07320 | 1.07298 | **1.07248** | +| 314 | 1.07214 | 1.07200 | 1.07203 | **1.07189** | +| Mean | 1.07326 | 1.07300 | 1.07266 | **1.07209** | + +Every seed improves monotonically across each change. + +## Legality (Issue #1017) + +- **Condition 1 (Causal)**: single left-to-right pass; LoRA state at `t` + depends only on earlier tokens of the same doc. +- **Condition 2 (Full normalized distribution)**: standard softmax over + the 8192 SentencePiece tokens. +- **Condition 3 (Score-before-update)**: each chunk is scored through + `forward_ttt_train` *before* the optimizer step on that chunk. +- **Condition 4 (Single pass)**: one left-to-right pass, no rescoring. + +## Attribution + +- @bigbag (PR #1493) — triple depth recurrence, parallel residuals +- @EthanYangTW (PR #1523) — parameter banking refinements +- @samacqua (PR #1530) — VarLen attention, Fused Triton MLP, doc-independent LoRA TTT +- @romeerp (PR #1610) — phased TTT (single-phase global SGD) +- @dexhunter (1.07193 submission) — multi-phase global SGD, trimmed GPTQ, MATRIX_LR=0.026, per-layer clip sigmas, int7 embeddings +- @abaybektursun (PR #549) — legal TTT framework + +## Reproduction + +```bash +export DATA_DIR=/path/to/parameter-golf/data +torchrun --standalone --nproc_per_node=8 train_gpt.py # seed 1337 +SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py +SEED=314 torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +All four novel hyperparameters are hardcoded as defaults in `train_gpt.py`: +`TTT_LORA_RANK=128`, `TTT_LORA_ALPHA=144`, `TTT_WARM_START_A=1`, +`TTT_WEIGHT_DECAY=1.0`. diff --git a/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/requirements.txt b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/requirements.txt new file mode 100644 index 0000000000..5e4f03a11f --- /dev/null +++ b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/requirements.txt @@ -0,0 +1,7 @@ +torch>=2.9 +flash-attn>=3.0 +triton>=3.5 +sentencepiece +python-minifier +brotli +numpy diff --git a/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/submission.json b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/submission.json new file mode 100644 index 0000000000..e258330b6a --- /dev/null +++ b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/submission.json @@ -0,0 +1,57 @@ +{ + "authors": [ + { + "name": "Renqian Luo", + "github_id": "renqianluo" + } + ], + "description": "Four composable novel TTT-LoRA improvements on top of dexhunter's phased TTT pipeline: (1) alpha/rank output scaling, (2) warm-start of LoRA A across batches, (3) raised TTT weight decay 0.5->1.0 to regularize the warmed A, (4) alpha lifted to 144 (vs rank 128 so effective scale=1.125) giving LoRA more adaptation strength while WD=1.0 keeps it stable. 3-seed mean 1.07209 BPB.", + "val_bpb": 1.07208661, + "seed_results": { + "1337": 1.07189164, + "42": 1.07247808, + "314": 1.07189010 + }, + "eval_time_seconds": { + "1337": 456.5, + "42": 456.7, + "314": 455.7 + }, + "train_time_seconds": { + "1337": 596.0, + "42": 596.0, + "314": 596.1 + }, + "artifact_size_bytes": { + "1337": 15935101, + "42": 15930195, + "314": 15935817 + }, + "methods": [ + "Novel (1): alpha/rank scaling on BatchedLinearLoRA — decouples rank from effective LR, enabling stable rank 128", + "Novel (2): warm-start of LoRA A across batches (only B resets to zero) — feature directions accumulate", + "Novel (3): TTT weight decay 0.5 -> 1.0 — counteracts the across-batch A overfit enabled by (2)", + "Novel (4): alpha lifted from 96 to 144 (scale = 144/128 = 1.125). With WD=1.0 as a regularizer the stronger LoRA converges cleanly on all 3 seeds instead of drifting.", + "LoRA rank 128, alpha 144, WD 1.0, warm-start A enabled", + "Inherits dexhunter's phased TTT + multi-phase global SGD + trimmed GPTQ + MATRIX_LR=0.026; samacqua's VarLen + Fused Triton MLP; bigbag's triple depth recurrence + parallel residuals" + ], + "attribution": { + "alpha_scaled_lora__warm_start_A__higher_wd__raised_alpha": "Renqian Luo (this work)", + "varlen_attention_fused_mlp_doc_ttt": "@samacqua (PR #1530)", + "phased_ttt_concept": "@romeerp (PR #1610)", + "multi_phase_global_sgd_trimmed_gptq": "@dexhunter", + "triple_recurrence_parallel_residuals": "@bigbag (PR #1493), @EthanYangTW (PR #1523)", + "legal_ttt_framework": "@abaybektursun (PR #549)" + }, + "legal_ttt": true, + "compliance": { + "train_under_600s": true, + "eval_under_600s": true, + "artifact_under_16mb": true, + "no_slot": true, + "no_pre_quant_ttt": true, + "no_ngram_cache": true, + "score_first_ttt": true, + "three_seeds": true + } +} diff --git a/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_gpt.py b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_gpt.py new file mode 100644 index 0000000000..27f49ebc78 --- /dev/null +++ b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_gpt.py @@ -0,0 +1,2985 @@ +import base64, collections, copy, fcntl, glob, io, json, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + sliding_window_enabled = bool(int(os.environ.get("SLIDING_WINDOW_ENABLED", "0"))) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 128)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 1.0)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.999)) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "brotli") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 4.0)) + phased_ttt_enabled = bool(int(os.environ.get("PHASED_TTT_ENABLED", "1"))) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2000)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 3)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 7)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 15.0)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 12.0)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + datasets_dir = os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}") + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + tokenizer_path = os.path.join( + data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model" + ) + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._next_batch = None + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + if self.bos_idx.size == 0: + self.bos_idx = np.array([0], dtype=np.int64) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = buf[:-1].to(dtype=torch.int64).pin_memory() + targets = buf[1:].to(dtype=torch.int64).pin_memory() + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + if self._next_batch is not None: + inputs, targets, cu_seqlens, max_seqlen = self._next_batch.result() + else: + inputs, targets, cu_seqlens, max_seqlen = self._prepare_batch( + num_tokens_local, self.max_seq_len + ) + self._next_batch = self._batch_pool.submit( + self._prepare_batch, num_tokens_local, self.max_seq_len + ) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.5 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.5 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.5 * c0) + aux1 = tl.where(c1 > 0, c1, 0.5 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 256, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + logits = self.forward_logits( + input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), + reduction="mean", + ) + + def forward_ttt(self, input_ids, target_ids, lora): + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q = (F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n)).reshape( + bsz, seqlen, attn.num_heads, attn.head_dim + ) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q = (F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n)).reshape( + bsz, seqlen, attn.num_heads, attn.head_dim + ) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # Novel: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def reset(self): + with torch.no_grad(): + # Novel: optionally keep A warm across batch resets (accumulates feature directions). + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + self.v_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + a, b, c = 3.4445, -4.775, 2.0315 + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad.bfloat16()) + if pg.shape[0] > m["B"]: + pg[m["B"] :].zero_() + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self.optimizer_tok.step() + self.optimizer_scalar.step() + self.optimizer_muon.step() + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s + + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + ret = gptq_quantize_weight( + t, hessians[name], clip_sigmas=cs, clip_range=clip_range + ) + q, s = ret + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = ( + q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + ).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off : dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off : src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data, compressor): + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + raw = _byte_unshuffle(raw) + return raw + + +def _unbank_state_dict(state_dict, num_layers): + sd = {} + n = num_layers + for k, v in state_dict.items(): + t = v.detach().cpu() if v is not None else None + if k == "qo_bank": + for i in range(n): + sd[f"blocks.{i}.attn.c_q.weight"] = t[i] + sd[f"blocks.{i}.attn.proj.weight"] = t[n + i] + elif k == "kv_bank": + for i in range(n): + sd[f"blocks.{i}.attn.c_k.weight"] = t[i] + sd[f"blocks.{i}.attn.c_v.weight"] = t[n + i] + elif k == "mlp_up_bank": + for i in range(n): + sd[f"blocks.{i}.mlp.fc.weight"] = t[i] + elif k == "mlp_down_bank": + for i in range(n): + sd[f"blocks.{i}.mlp.proj.weight"] = t[i] + else: + if t is not None: + sd[k] = t + return sd + + +def _rebank_state_dict(flat_sd, num_layers, model_dim, kv_dim, hidden_dim): + sd = {} + n = num_layers + sd["qo_bank"] = torch.zeros(2 * n, model_dim, model_dim) + sd["kv_bank"] = torch.zeros(2 * n, kv_dim, model_dim) + for i in range(n): + sd["qo_bank"][i] = flat_sd[f"blocks.{i}.attn.c_q.weight"] + sd["qo_bank"][n + i] = flat_sd[f"blocks.{i}.attn.proj.weight"] + sd["kv_bank"][i] = flat_sd[f"blocks.{i}.attn.c_k.weight"] + sd["kv_bank"][n + i] = flat_sd[f"blocks.{i}.attn.c_v.weight"] + sd["mlp_up_bank"] = torch.zeros(n, hidden_dim, model_dim) + sd["mlp_down_bank"] = torch.zeros(n, model_dim, hidden_dim) + for i in range(n): + sd["mlp_up_bank"][i] = flat_sd[f"blocks.{i}.mlp.fc.weight"] + sd["mlp_down_bank"][i] = flat_sd[f"blocks.{i}.mlp.proj.weight"] + for k, v in flat_sd.items(): + if not ( + k.startswith("blocks.") + and any( + p in k + for p in [ + ".attn.c_q.", ".attn.c_k.", ".attn.c_v.", + ".attn.proj.", ".mlp.fc.", ".mlp.proj.", + ] + ) + ): + sd[k] = v + return sd + + + +def _compressed_code_size(code): + code_raw = code.encode("utf-8") + minified = subprocess.run( + ["pyminify", "--no-rename-locals", "--no-hoist-literals", "--remove-literal-statements", "-"], + input=code_raw, capture_output=True, check=True, + ).stdout + compressed = lzma.compress(minified) + encoded = base64.b85encode(compressed) + wrapper = b'import lzma as L,base64 as B\nexec(L.decompress(B.b85decode("' + encoded + b'")))\n' + return len(code_raw), len(wrapper) + + +def serialize(h, base_model, code): + code_bytes_uncompressed, code_bytes = _compressed_code_size(code) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size (uncompressed): {code_bytes_uncompressed} bytes") + log(f"Code size (compressed): {code_bytes} bytes") + sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers) + device = torch.device("cuda", h.local_rank) + log("GPTQ:collecting Hessians from calibration data...") + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + hessians = collect_hessians( + base_model, + calib_loader, + h, + device, + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter()-t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + flat_template = _unbank_state_dict(eval_model.state_dict(), h.num_layers) + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location="cpu" + ) + deq_flat = dequantize_mixed(quant_state["w"], quant_state["m"], flat_template) + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + deq_state = _rebank_state_dict(deq_flat, h.num_layers, h.model_dim, kv_dim, hidden_dim) + eval_model.load_state_dict(deq_state, strict=True) + return eval_model + + +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += ( + val_data.has_leading_space_lut[tgt_ids] + & ~val_data.is_boundary_token_lut[prev_ids] + ).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding(h, device, val_data, base_model, forward_logits_fn=None, batch_seqs=32): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + run_forward_logits = base_model.forward_logits if forward_logits_fn is None else forward_logits_fn + seq_len = h.eval_seq_len + stride = h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + context_size = seq_len - stride + window_starts = [ws for ws in range(0, total_tokens, stride) + if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * (h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + total_batches = (len(my_windows) + batch_seqs - 1) // batch_seqs + is_master = h.rank == 0 + cu_bucket = 64 + t_sw_start = time.perf_counter() + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_idx = bi // batch_seqs + if is_master and (batch_idx % 50 == 0 or batch_idx == total_batches - 1): + elapsed = time.perf_counter() - t_sw_start + rl = float(loss_sum.item() / token_count.item()) if token_count.item() > 0 else 0.0 + rb = float((rl / math.log(2.0)) * token_count.item() / byte_count.item()) if byte_count.item() > 0 else 0.0 + log(f"sliding_progress: batch {batch_idx+1}/{total_batches} " + f"tokens:{int(token_count.item())} running_loss:{rl:.4f} running_bpb:{rb:.4f} " + f"elapsed:{elapsed:.1f}s") + batch_ws = my_windows[bi:bi + batch_seqs] + x_parts = [] + y_parts = [] + cu_starts = [] + score_ranges = [] + offset = 0 + for ws in batch_ws: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + chunk_cpu = val_data.val_tokens[ws:end + 1] + bos_pos = (chunk_cpu[:-1] == BOS_ID).nonzero(as_tuple=True)[0].tolist() + if not bos_pos or bos_pos[0] != 0: + bos_pos = [0] + bos_pos + cu_starts.extend(offset + pos for pos in bos_pos) + chunk = chunk_cpu.to(dtype=torch.int64, device=device) + x_parts.append(chunk[:-1]) + y_parts.append(chunk[1:]) + score_ranges.append((offset, wlen, ws)) + offset += wlen + x_cat = torch.cat(x_parts, dim=0)[None] + y_cat = torch.cat(y_parts, dim=0) + boundaries = cu_starts + [offset] + padded_len = get_next_multiple_of_n(len(boundaries), cu_bucket) + cu_seqlens = torch.full((padded_len,), offset, dtype=torch.int32, device=device) + cu_seqlens[:len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = run_forward_logits(x_cat, cu_seqlens=cu_seqlens, max_seqlen=seq_len) + flat_nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_cat, + reduction="none", + ) + flat_x = x_cat.reshape(-1) + for off, wlen, ws in score_ranges: + s = 0 if ws == 0 else context_size + lo = off + s + hi = off + wlen + scored_nll = flat_nll[lo:hi].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(hi - lo) + tgt = y_cat[lo:hi] + prev = flat_x[lo:hi] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _split_doc_entries_for_phased(doc_entries, prefix_docs): + prefix_docs = max(0, min(len(doc_entries), int(prefix_docs))) + return doc_entries[:prefix_docs], doc_entries[prefix_docs:] + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + with torch.no_grad(): + _accumulate_bpb( + per_tok_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + ) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + for gi in range(h.ttt_grad_steps): + if gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[ + :, chunk_offset : chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + frac = ( + min(step / h.muon_momentum_warmup_steps, 1.0) + if h.muon_momentum_warmup_steps > 0 + else 1.0 + ) + muon_momentum = ( + 1 - frac + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + ema_state = { + name: t.detach().float().clone() + for (name, t) in base_model.state_dict().items() + } + ema_decay = h.ema_decay + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale) + with torch.no_grad(): + for (name, t) in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_( + t.detach().float(), alpha=1.0 - ema_decay + ) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + torch._dynamo.reset() + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if h.sliding_window_enabled: + timed_eval( + "diagnostic quantized_sliding_window", + eval_val_sliding, + h, + device, + val_data, + eval_model, + forward_logits_fn=compiled_forward_logits, + ) + if h.ttt_enabled: + del eval_model, compiled_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + _fwd_ttt_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora): + nonlocal _fwd_ttt_compiled_inner + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile (random tokens, no val data)") + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + del ttt_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 16 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed1337.log b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed1337.log new file mode 100644 index 0000000000..dc9d8c491f --- /dev/null +++ b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed1337.log @@ -0,0 +1,761 @@ +[Wed Apr 22 00:10:57 UTC 2026] Starting SP8192 dexhunter experiment: pure_ttt_sp8192_dxa144_seed1337_8gpu +Extra args: SEED=1337 PHASED_TTT_ENABLED=1 PHASED_TTT_PREFIX_DOCS=2000 PHASED_TTT_NUM_PHASES=3 MLP_CLIP_SIGMAS=12.0 ATTN_CLIP_SIGMAS=13.0 EMBED_BITS=7 EMBED_CLIP_SIGMAS=15.0 MATRIX_LR=0.026 GPTQ_RESERVE_SECONDS=4 GPTQ_CALIBRATION_BATCHES=16 TTT_LORA_RANK=128 TTT_LORA_ALPHA=144 TTT_WARM_START_A=1 TTT_WEIGHT_DECAY=1.0 +torch: 2.9.1+cu128 +FA3 interface: OK +[Wed Apr 22 00:11:02 UTC 2026] GPU check: +0, 1 MiB, 0 % +1, 1 MiB, 0 % +2, 1 MiB, 0 % +3, 1 MiB, 0 % +4, 1 MiB, 0 % +5, 1 MiB, 0 % +6, 1 MiB, 0 % +7, 1 MiB, 0 % +W0422 00:11:03.905000 112 torch/distributed/run.py:803] +W0422 00:11:03.905000 112 torch/distributed/run.py:803] ***************************************** +W0422 00:11:03.905000 112 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0422 00:11:03.905000 112 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: + attn_clip_sigmas: 13.0 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: DATA_DIR + datasets_dir: DATA_DIR/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 15.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/91aa9b72-d152-4699-b179-4defd726d0e8.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_clip_sigmas: 12.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_enabled: True + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: 91aa9b72-d152-4699-b179-4defd726d0e8 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: DATA_DIR/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: DATA_DIR/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 128 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 1.0 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: DATA_DIR/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +gptq:reserving 4s, effective=596000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0095 val_bpb: 3.4877 +1/20000 train_loss: 9.0094 train_time: 0.0m tok/s: 16581994 +2/20000 train_loss: 12.2129 train_time: 0.0m tok/s: 12963897 +3/20000 train_loss: 11.2327 train_time: 0.0m tok/s: 11017492 +4/20000 train_loss: 9.6042 train_time: 0.0m tok/s: 10157279 +5/20000 train_loss: 8.2122 train_time: 0.0m tok/s: 9767453 +500/20000 train_loss: 3.2681 train_time: 0.8m tok/s: 8286062 +1000/20000 train_loss: 3.0278 train_time: 1.6m tok/s: 8248597 +1500/20000 train_loss: 3.0374 train_time: 2.4m tok/s: 8250022 +2000/20000 train_loss: 2.9874 train_time: 3.2m tok/s: 8249147 +layer_loop:enabled step:2187 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0798 train_time: 4.2m tok/s: 7800327 +3000/20000 train_loss: 2.9196 train_time: 5.4m tok/s: 7333523 +3500/20000 train_loss: 2.9842 train_time: 6.5m tok/s: 7036454 +4000/20000 train_loss: 2.9171 train_time: 7.7m tok/s: 6829773 +4000/20000 val_loss: 2.8938 val_bpb: 1.1203 +4500/20000 train_loss: 2.8754 train_time: 8.8m tok/s: 6676218 +4952/20000 val_loss: 2.7748 val_bpb: 1.0742 +stopping_early: wallclock_cap train_time: 596151ms step: 4952/20000 +peak memory allocated: 40029 MiB reserved: 44036 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.77352391 val_bpb:1.07368293 eval_time:6579ms +Serialized model: 135409136 bytes +Code size (uncompressed): 122656 bytes +Code size (compressed): 27680 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 3.6s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15907421 bytes +Total submission size quantized+brotli: 15935101 bytes +diagnostic quantized val_loss:2.80324509 val_bpb:1.08518856 eval_time:59228ms +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (151.0s) + +beginning TTT eval timer +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 num_phases:3 boundaries:[666, 1333, 2000] +ttp: b776/782 bl:2.7189 bb:1.0879 rl:2.7189 rb:1.0879 dl:6364-7180 gd:0 +ttp: b773/782 bl:2.6517 bb:1.0760 rl:2.6892 rb:1.0826 dl:5203-5550 gd:0 +ttp: b768/782 bl:2.7028 bb:1.0846 rl:2.6927 rb:1.0832 dl:4128-4306 gd:0 +ttpp: phase:1/3 pd:1104 gd:666 t:199.2s +tttg: c1/95 lr:0.001000 t:1.7s +tttg: c2/95 lr:0.001000 t:1.7s +tttg: c3/95 lr:0.000999 t:1.8s +tttg: c4/95 lr:0.000997 t:1.9s +tttg: c5/95 lr:0.000996 t:2.0s +tttg: c6/95 lr:0.000993 t:2.0s +tttg: c7/95 lr:0.000990 t:2.1s +tttg: c8/95 lr:0.000986 t:2.2s +tttg: c9/95 lr:0.000982 t:2.2s +tttg: c10/95 lr:0.000978 t:2.3s +tttg: c11/95 lr:0.000972 t:2.4s +tttg: c12/95 lr:0.000967 t:2.4s +tttg: c13/95 lr:0.000960 t:2.5s +tttg: c14/95 lr:0.000954 t:2.6s +tttg: c15/95 lr:0.000946 t:2.7s +tttg: c16/95 lr:0.000938 t:2.7s +tttg: c17/95 lr:0.000930 t:2.8s +tttg: c18/95 lr:0.000921 t:2.9s +tttg: c19/95 lr:0.000912 t:2.9s +tttg: c20/95 lr:0.000903 t:3.0s +tttg: c21/95 lr:0.000892 t:3.1s +tttg: c22/95 lr:0.000882 t:3.1s +tttg: c23/95 lr:0.000871 t:3.2s +tttg: c24/95 lr:0.000859 t:3.3s +tttg: c25/95 lr:0.000848 t:3.4s +tttg: c26/95 lr:0.000835 t:3.4s +tttg: c27/95 lr:0.000823 t:3.5s +tttg: c28/95 lr:0.000810 t:3.6s +tttg: c29/95 lr:0.000797 t:3.6s +tttg: c30/95 lr:0.000783 t:3.7s +tttg: c31/95 lr:0.000769 t:3.8s +tttg: c32/95 lr:0.000755 t:3.8s +tttg: c33/95 lr:0.000740 t:3.9s +tttg: c34/95 lr:0.000726 t:4.0s +tttg: c35/95 lr:0.000710 t:4.0s +tttg: c36/95 lr:0.000695 t:4.1s +tttg: c37/95 lr:0.000680 t:4.2s +tttg: c38/95 lr:0.000664 t:4.3s +tttg: c39/95 lr:0.000648 t:4.3s +tttg: c40/95 lr:0.000632 t:4.4s +tttg: c41/95 lr:0.000616 t:4.5s +tttg: c42/95 lr:0.000600 t:4.5s +tttg: c43/95 lr:0.000583 t:4.6s +tttg: c44/95 lr:0.000567 t:4.7s +tttg: c45/95 lr:0.000550 t:4.8s +tttg: c46/95 lr:0.000533 t:4.8s +tttg: c47/95 lr:0.000517 t:4.9s +tttg: c48/95 lr:0.000500 t:5.0s +tttg: c49/95 lr:0.000483 t:5.0s +tttg: c50/95 lr:0.000467 t:5.1s +tttg: c51/95 lr:0.000450 t:5.2s +tttg: c52/95 lr:0.000433 t:5.2s +tttg: c53/95 lr:0.000417 t:5.3s +tttg: c54/95 lr:0.000400 t:5.4s +tttg: c55/95 lr:0.000384 t:5.4s +tttg: c56/95 lr:0.000368 t:5.5s +tttg: c57/95 lr:0.000352 t:5.6s +tttg: c58/95 lr:0.000336 t:5.6s +tttg: c59/95 lr:0.000320 t:5.7s +tttg: c60/95 lr:0.000305 t:5.8s +tttg: c61/95 lr:0.000290 t:5.9s +tttg: c62/95 lr:0.000274 t:5.9s +tttg: c63/95 lr:0.000260 t:6.0s +tttg: c64/95 lr:0.000245 t:6.1s +tttg: c65/95 lr:0.000231 t:6.1s +tttg: c66/95 lr:0.000217 t:6.2s +tttg: c67/95 lr:0.000203 t:6.3s +tttg: c68/95 lr:0.000190 t:6.4s +tttg: c69/95 lr:0.000177 t:6.4s +tttg: c70/95 lr:0.000165 t:6.5s +tttg: c71/95 lr:0.000152 t:6.6s +tttg: c72/95 lr:0.000141 t:6.6s +tttg: c73/95 lr:0.000129 t:6.7s +tttg: c74/95 lr:0.000118 t:6.8s +tttg: c75/95 lr:0.000108 t:6.8s +tttg: c76/95 lr:0.000097 t:6.9s +tttg: c77/95 lr:0.000088 t:7.0s +tttg: c78/95 lr:0.000079 t:7.0s +tttg: c79/95 lr:0.000070 t:7.1s +tttg: c80/95 lr:0.000062 t:7.2s +tttg: c81/95 lr:0.000054 t:7.2s +tttg: c82/95 lr:0.000046 t:7.3s +tttg: c83/95 lr:0.000040 t:7.4s +tttg: c84/95 lr:0.000033 t:7.4s +tttg: c85/95 lr:0.000028 t:7.5s +tttg: c86/95 lr:0.000022 t:7.6s +tttg: c87/95 lr:0.000018 t:7.7s +tttg: c88/95 lr:0.000014 t:7.7s +tttg: c89/95 lr:0.000010 t:7.8s +tttg: c90/95 lr:0.000007 t:7.9s +tttg: c91/95 lr:0.000004 t:8.0s +tttg: c92/95 lr:0.000003 t:8.0s +tttg: c93/95 lr:0.000001 t:8.1s +tttg: c94/95 lr:0.000000 t:8.2s +ttpr: phase:1/3 t:209.8s +ttp: b757/782 bl:2.6379 bb:1.0195 rl:2.6840 rb:1.0727 dl:3033-3108 gd:0 +ttpp: phase:2/3 pd:1808 gd:1333 t:320.1s +tttg: c1/158 lr:0.001000 t:0.1s +tttg: c2/158 lr:0.001000 t:0.1s +tttg: c3/158 lr:0.001000 t:0.2s +tttg: c4/158 lr:0.000999 t:0.3s +tttg: c5/158 lr:0.000998 t:0.4s +tttg: c6/158 lr:0.000997 t:0.4s +tttg: c7/158 lr:0.000996 t:0.5s +tttg: c8/158 lr:0.000995 t:0.6s +tttg: c9/158 lr:0.000994 t:0.6s +tttg: c10/158 lr:0.000992 t:0.7s +tttg: c11/158 lr:0.000990 t:0.8s +tttg: c12/158 lr:0.000988 t:0.9s +tttg: c13/158 lr:0.000986 t:0.9s +tttg: c14/158 lr:0.000983 t:1.0s +tttg: c15/158 lr:0.000981 t:1.1s +tttg: c16/158 lr:0.000978 t:1.1s +tttg: c17/158 lr:0.000975 t:1.2s +tttg: c18/158 lr:0.000971 t:1.3s +tttg: c19/158 lr:0.000968 t:1.3s +tttg: c20/158 lr:0.000964 t:1.4s +tttg: c21/158 lr:0.000960 t:1.5s +tttg: c22/158 lr:0.000957 t:1.6s +tttg: c23/158 lr:0.000952 t:1.6s +tttg: c24/158 lr:0.000948 t:1.7s +tttg: c25/158 lr:0.000943 t:1.8s +tttg: c26/158 lr:0.000939 t:1.8s +tttg: c27/158 lr:0.000934 t:1.9s +tttg: c28/158 lr:0.000929 t:2.0s +tttg: c29/158 lr:0.000924 t:2.1s +tttg: c30/158 lr:0.000918 t:2.2s +tttg: c31/158 lr:0.000913 t:2.2s +tttg: c32/158 lr:0.000907 t:2.3s +tttg: c33/158 lr:0.000901 t:2.4s +tttg: c34/158 lr:0.000895 t:2.4s +tttg: c35/158 lr:0.000889 t:2.5s +tttg: c36/158 lr:0.000882 t:2.6s +tttg: c37/158 lr:0.000876 t:2.6s +tttg: c38/158 lr:0.000869 t:2.7s +tttg: c39/158 lr:0.000862 t:2.8s +tttg: c40/158 lr:0.000855 t:2.9s +tttg: c41/158 lr:0.000848 t:2.9s +tttg: c42/158 lr:0.000841 t:3.0s +tttg: c43/158 lr:0.000834 t:3.1s +tttg: c44/158 lr:0.000826 t:3.1s +tttg: c45/158 lr:0.000818 t:3.2s +tttg: c46/158 lr:0.000811 t:3.3s +tttg: c47/158 lr:0.000803 t:3.3s +tttg: c48/158 lr:0.000795 t:3.4s +tttg: c49/158 lr:0.000787 t:3.5s +tttg: c50/158 lr:0.000778 t:3.6s +tttg: c51/158 lr:0.000770 t:3.6s +tttg: c52/158 lr:0.000761 t:3.7s +tttg: c53/158 lr:0.000753 t:3.8s +tttg: c54/158 lr:0.000744 t:3.9s +tttg: c55/158 lr:0.000735 t:3.9s +tttg: c56/158 lr:0.000727 t:4.0s +tttg: c57/158 lr:0.000718 t:4.1s +tttg: c58/158 lr:0.000709 t:4.1s +tttg: c59/158 lr:0.000699 t:4.2s +tttg: c60/158 lr:0.000690 t:4.3s +tttg: c61/158 lr:0.000681 t:4.3s +tttg: c62/158 lr:0.000672 t:4.4s +tttg: c63/158 lr:0.000662 t:4.5s +tttg: c64/158 lr:0.000653 t:4.6s +tttg: c65/158 lr:0.000643 t:4.6s +tttg: c66/158 lr:0.000633 t:4.7s +tttg: c67/158 lr:0.000624 t:4.8s +tttg: c68/158 lr:0.000614 t:4.8s +tttg: c69/158 lr:0.000604 t:4.9s +tttg: c70/158 lr:0.000594 t:5.0s +tttg: c71/158 lr:0.000585 t:5.0s +tttg: c72/158 lr:0.000575 t:5.1s +tttg: c73/158 lr:0.000565 t:5.2s +tttg: c74/158 lr:0.000555 t:5.2s +tttg: c75/158 lr:0.000545 t:5.3s +tttg: c76/158 lr:0.000535 t:5.4s +tttg: c77/158 lr:0.000525 t:5.5s +tttg: c78/158 lr:0.000515 t:5.5s +tttg: c79/158 lr:0.000505 t:5.6s +tttg: c80/158 lr:0.000495 t:5.7s +tttg: c81/158 lr:0.000485 t:5.8s +tttg: c82/158 lr:0.000475 t:5.8s +tttg: c83/158 lr:0.000465 t:5.9s +tttg: c84/158 lr:0.000455 t:6.0s +tttg: c85/158 lr:0.000445 t:6.0s +tttg: c86/158 lr:0.000435 t:6.1s +tttg: c87/158 lr:0.000425 t:6.2s +tttg: c88/158 lr:0.000415 t:6.3s +tttg: c89/158 lr:0.000406 t:6.3s +tttg: c90/158 lr:0.000396 t:6.4s +tttg: c91/158 lr:0.000386 t:6.5s +tttg: c92/158 lr:0.000376 t:6.5s +tttg: c93/158 lr:0.000367 t:6.6s +tttg: c94/158 lr:0.000357 t:6.7s +tttg: c95/158 lr:0.000347 t:6.7s +tttg: c96/158 lr:0.000338 t:6.8s +tttg: c97/158 lr:0.000328 t:6.9s +tttg: c98/158 lr:0.000319 t:6.9s +tttg: c99/158 lr:0.000310 t:7.0s +tttg: c100/158 lr:0.000301 t:7.1s +tttg: c101/158 lr:0.000291 t:7.2s +tttg: c102/158 lr:0.000282 t:7.2s +tttg: c103/158 lr:0.000273 t:7.3s +tttg: c104/158 lr:0.000265 t:7.4s +tttg: c105/158 lr:0.000256 t:7.4s +tttg: c106/158 lr:0.000247 t:7.5s +tttg: c107/158 lr:0.000239 t:7.6s +tttg: c108/158 lr:0.000230 t:7.7s +tttg: c109/158 lr:0.000222 t:7.7s +tttg: c110/158 lr:0.000213 t:7.8s +tttg: c111/158 lr:0.000205 t:7.9s +tttg: c112/158 lr:0.000197 t:7.9s +tttg: c113/158 lr:0.000189 t:8.0s +tttg: c114/158 lr:0.000182 t:8.1s +tttg: c115/158 lr:0.000174 t:8.1s +tttg: c116/158 lr:0.000166 t:8.2s +tttg: c117/158 lr:0.000159 t:8.3s +tttg: c118/158 lr:0.000152 t:8.4s +tttg: c119/158 lr:0.000145 t:8.4s +tttg: c120/158 lr:0.000138 t:8.5s +tttg: c121/158 lr:0.000131 t:8.6s +tttg: c122/158 lr:0.000124 t:8.6s +tttg: c123/158 lr:0.000118 t:8.7s +tttg: c124/158 lr:0.000111 t:8.8s +tttg: c125/158 lr:0.000105 t:8.8s +tttg: c126/158 lr:0.000099 t:8.9s +tttg: c127/158 lr:0.000093 t:9.0s +tttg: c128/158 lr:0.000087 t:9.1s +tttg: c129/158 lr:0.000082 t:9.1s +tttg: c130/158 lr:0.000076 t:9.2s +tttg: c131/158 lr:0.000071 t:9.3s +tttg: c132/158 lr:0.000066 t:9.3s +tttg: c133/158 lr:0.000061 t:9.4s +tttg: c134/158 lr:0.000057 t:9.5s +tttg: c135/158 lr:0.000052 t:9.6s +tttg: c136/158 lr:0.000048 t:9.6s +tttg: c137/158 lr:0.000043 t:9.7s +tttg: c138/158 lr:0.000040 t:9.8s +tttg: c139/158 lr:0.000036 t:9.9s +tttg: c140/158 lr:0.000032 t:9.9s +tttg: c141/158 lr:0.000029 t:10.0s +tttg: c142/158 lr:0.000025 t:10.1s +tttg: c143/158 lr:0.000022 t:10.1s +tttg: c144/158 lr:0.000019 t:10.2s +tttg: c145/158 lr:0.000017 t:10.3s +tttg: c146/158 lr:0.000014 t:10.4s +tttg: c147/158 lr:0.000012 t:10.4s +tttg: c148/158 lr:0.000010 t:10.5s +tttg: c149/158 lr:0.000008 t:10.6s +tttg: c150/158 lr:0.000006 t:10.6s +tttg: c151/158 lr:0.000005 t:10.7s +tttg: c152/158 lr:0.000004 t:10.8s +tttg: c153/158 lr:0.000003 t:10.8s +tttg: c154/158 lr:0.000002 t:10.9s +tttg: c155/158 lr:0.000001 t:11.0s +tttg: c156/158 lr:0.000000 t:11.1s +tttg: c157/158 lr:0.000000 t:11.2s +ttpr: phase:2/3 t:333.7s +ttp: b746/782 bl:2.6772 bb:1.0541 rl:2.6832 rb:1.0706 dl:2459-2501 gd:0 +ttp: b744/782 bl:2.6497 bb:1.0556 rl:2.6799 rb:1.0691 dl:2388-2419 gd:0 +ttpp: phase:3/3 pd:2448 gd:2000 t:347.6s +tttg: c1/213 lr:0.001000 t:0.1s +tttg: c2/213 lr:0.001000 t:0.1s +tttg: c3/213 lr:0.001000 t:0.2s +tttg: c4/213 lr:0.001000 t:0.3s +tttg: c5/213 lr:0.000999 t:0.4s +tttg: c6/213 lr:0.000999 t:0.5s +tttg: c7/213 lr:0.000998 t:0.5s +tttg: c8/213 lr:0.000997 t:0.6s +tttg: c9/213 lr:0.000996 t:0.7s +tttg: c10/213 lr:0.000996 t:0.7s +tttg: c11/213 lr:0.000995 t:0.8s +tttg: c12/213 lr:0.000993 t:0.9s +tttg: c13/213 lr:0.000992 t:0.9s +tttg: c14/213 lr:0.000991 t:1.0s +tttg: c15/213 lr:0.000989 t:1.1s +tttg: c16/213 lr:0.000988 t:1.2s +tttg: c17/213 lr:0.000986 t:1.2s +tttg: c18/213 lr:0.000984 t:1.3s +tttg: c19/213 lr:0.000982 t:1.4s +tttg: c20/213 lr:0.000980 t:1.4s +tttg: c21/213 lr:0.000978 t:1.5s +tttg: c22/213 lr:0.000976 t:1.6s +tttg: c23/213 lr:0.000974 t:1.6s +tttg: c24/213 lr:0.000971 t:1.7s +tttg: c25/213 lr:0.000969 t:1.8s +tttg: c26/213 lr:0.000966 t:1.9s +tttg: c27/213 lr:0.000963 t:1.9s +tttg: c28/213 lr:0.000961 t:2.0s +tttg: c29/213 lr:0.000958 t:2.1s +tttg: c30/213 lr:0.000955 t:2.1s +tttg: c31/213 lr:0.000951 t:2.2s +tttg: c32/213 lr:0.000948 t:2.3s +tttg: c33/213 lr:0.000945 t:2.4s +tttg: c34/213 lr:0.000941 t:2.4s +tttg: c35/213 lr:0.000938 t:2.5s +tttg: c36/213 lr:0.000934 t:2.6s +tttg: c37/213 lr:0.000931 t:2.6s +tttg: c38/213 lr:0.000927 t:2.7s +tttg: c39/213 lr:0.000923 t:2.8s +tttg: c40/213 lr:0.000919 t:2.8s +tttg: c41/213 lr:0.000915 t:2.9s +tttg: c42/213 lr:0.000911 t:3.0s +tttg: c43/213 lr:0.000906 t:3.1s +tttg: c44/213 lr:0.000902 t:3.1s +tttg: c45/213 lr:0.000897 t:3.2s +tttg: c46/213 lr:0.000893 t:3.3s +tttg: c47/213 lr:0.000888 t:3.3s +tttg: c48/213 lr:0.000884 t:3.4s +tttg: c49/213 lr:0.000879 t:3.5s +tttg: c50/213 lr:0.000874 t:3.5s +tttg: c51/213 lr:0.000869 t:3.6s +tttg: c52/213 lr:0.000864 t:3.7s +tttg: c53/213 lr:0.000859 t:3.8s +tttg: c54/213 lr:0.000854 t:3.8s +tttg: c55/213 lr:0.000848 t:3.9s +tttg: c56/213 lr:0.000843 t:4.0s +tttg: c57/213 lr:0.000837 t:4.0s +tttg: c58/213 lr:0.000832 t:4.1s +tttg: c59/213 lr:0.000826 t:4.2s +tttg: c60/213 lr:0.000821 t:4.2s +tttg: c61/213 lr:0.000815 t:4.3s +tttg: c62/213 lr:0.000809 t:4.4s +tttg: c63/213 lr:0.000803 t:4.5s +tttg: c64/213 lr:0.000797 t:4.5s +tttg: c65/213 lr:0.000791 t:4.6s +tttg: c66/213 lr:0.000785 t:4.7s +tttg: c67/213 lr:0.000779 t:4.7s +tttg: c68/213 lr:0.000773 t:4.8s +tttg: c69/213 lr:0.000767 t:4.9s +tttg: c70/213 lr:0.000761 t:4.9s +tttg: c71/213 lr:0.000754 t:5.0s +tttg: c72/213 lr:0.000748 t:5.1s +tttg: c73/213 lr:0.000741 t:5.1s +tttg: c74/213 lr:0.000735 t:5.2s +tttg: c75/213 lr:0.000728 t:5.3s +tttg: c76/213 lr:0.000722 t:5.4s +tttg: c77/213 lr:0.000715 t:5.4s +tttg: c78/213 lr:0.000708 t:5.5s +tttg: c79/213 lr:0.000702 t:5.6s +tttg: c80/213 lr:0.000695 t:5.6s +tttg: c81/213 lr:0.000688 t:5.7s +tttg: c82/213 lr:0.000681 t:5.8s +tttg: c83/213 lr:0.000674 t:5.8s +tttg: c84/213 lr:0.000667 t:5.9s +tttg: c85/213 lr:0.000660 t:6.0s +tttg: c86/213 lr:0.000653 t:6.0s +tttg: c87/213 lr:0.000646 t:6.1s +tttg: c88/213 lr:0.000639 t:6.2s +tttg: c89/213 lr:0.000632 t:6.2s +tttg: c90/213 lr:0.000625 t:6.3s +tttg: c91/213 lr:0.000617 t:6.4s +tttg: c92/213 lr:0.000610 t:6.5s +tttg: c93/213 lr:0.000603 t:6.5s +tttg: c94/213 lr:0.000596 t:6.6s +tttg: c95/213 lr:0.000588 t:6.7s +tttg: c96/213 lr:0.000581 t:6.7s +tttg: c97/213 lr:0.000574 t:6.8s +tttg: c98/213 lr:0.000566 t:6.9s +tttg: c99/213 lr:0.000559 t:7.0s +tttg: c100/213 lr:0.000552 t:7.0s +tttg: c101/213 lr:0.000544 t:7.1s +tttg: c102/213 lr:0.000537 t:7.2s +tttg: c103/213 lr:0.000530 t:7.2s +tttg: c104/213 lr:0.000522 t:7.3s +tttg: c105/213 lr:0.000515 t:7.4s +tttg: c106/213 lr:0.000507 t:7.4s +tttg: c107/213 lr:0.000500 t:7.5s +tttg: c108/213 lr:0.000493 t:7.6s +tttg: c109/213 lr:0.000485 t:7.7s +tttg: c110/213 lr:0.000478 t:7.7s +tttg: c111/213 lr:0.000470 t:7.8s +tttg: c112/213 lr:0.000463 t:7.9s +tttg: c113/213 lr:0.000456 t:7.9s +tttg: c114/213 lr:0.000448 t:8.0s +tttg: c115/213 lr:0.000441 t:8.1s +tttg: c116/213 lr:0.000434 t:8.1s +tttg: c117/213 lr:0.000426 t:8.2s +tttg: c118/213 lr:0.000419 t:8.3s +tttg: c119/213 lr:0.000412 t:8.4s +tttg: c120/213 lr:0.000404 t:8.4s +tttg: c121/213 lr:0.000397 t:8.5s +tttg: c122/213 lr:0.000390 t:8.6s +tttg: c123/213 lr:0.000383 t:8.6s +tttg: c124/213 lr:0.000375 t:8.7s +tttg: c125/213 lr:0.000368 t:8.8s +tttg: c126/213 lr:0.000361 t:8.8s +tttg: c127/213 lr:0.000354 t:8.9s +tttg: c128/213 lr:0.000347 t:9.0s +tttg: c129/213 lr:0.000340 t:9.1s +tttg: c130/213 lr:0.000333 t:9.1s +tttg: c131/213 lr:0.000326 t:9.2s +tttg: c132/213 lr:0.000319 t:9.3s +tttg: c133/213 lr:0.000312 t:9.3s +tttg: c134/213 lr:0.000305 t:9.4s +tttg: c135/213 lr:0.000298 t:9.5s +tttg: c136/213 lr:0.000292 t:9.5s +tttg: c137/213 lr:0.000285 t:9.6s +tttg: c138/213 lr:0.000278 t:9.7s +tttg: c139/213 lr:0.000272 t:9.7s +tttg: c140/213 lr:0.000265 t:9.8s +tttg: c141/213 lr:0.000259 t:9.9s +tttg: c142/213 lr:0.000252 t:10.0s +tttg: c143/213 lr:0.000246 t:10.0s +tttg: c144/213 lr:0.000239 t:10.1s +tttg: c145/213 lr:0.000233 t:10.2s +tttg: c146/213 lr:0.000227 t:10.2s +tttg: c147/213 lr:0.000221 t:10.3s +tttg: c148/213 lr:0.000215 t:10.4s +tttg: c149/213 lr:0.000209 t:10.4s +tttg: c150/213 lr:0.000203 t:10.5s +tttg: c151/213 lr:0.000197 t:10.6s +tttg: c152/213 lr:0.000191 t:10.6s +tttg: c153/213 lr:0.000185 t:10.7s +tttg: c154/213 lr:0.000179 t:10.8s +tttg: c155/213 lr:0.000174 t:10.9s +tttg: c156/213 lr:0.000168 t:11.0s +tttg: c157/213 lr:0.000163 t:11.1s +tttg: c158/213 lr:0.000157 t:11.1s +tttg: c159/213 lr:0.000152 t:11.2s +tttg: c160/213 lr:0.000146 t:11.3s +tttg: c161/213 lr:0.000141 t:11.3s +tttg: c162/213 lr:0.000136 t:11.4s +tttg: c163/213 lr:0.000131 t:11.5s +tttg: c164/213 lr:0.000126 t:11.6s +tttg: c165/213 lr:0.000121 t:11.6s +tttg: c166/213 lr:0.000116 t:11.7s +tttg: c167/213 lr:0.000112 t:11.8s +tttg: c168/213 lr:0.000107 t:11.8s +tttg: c169/213 lr:0.000103 t:11.9s +tttg: c170/213 lr:0.000098 t:12.0s +tttg: c171/213 lr:0.000094 t:12.0s +tttg: c172/213 lr:0.000089 t:12.1s +tttg: c173/213 lr:0.000085 t:12.2s +tttg: c174/213 lr:0.000081 t:12.2s +tttg: c175/213 lr:0.000077 t:12.4s +tttg: c176/213 lr:0.000073 t:12.4s +tttg: c177/213 lr:0.000069 t:12.5s +tttg: c178/213 lr:0.000066 t:12.6s +tttg: c179/213 lr:0.000062 t:12.6s +tttg: c180/213 lr:0.000059 t:12.7s +tttg: c181/213 lr:0.000055 t:12.8s +tttg: c182/213 lr:0.000052 t:12.9s +tttg: c183/213 lr:0.000049 t:12.9s +tttg: c184/213 lr:0.000045 t:13.0s +tttg: c185/213 lr:0.000042 t:13.1s +tttg: c186/213 lr:0.000039 t:13.1s +tttg: c187/213 lr:0.000037 t:13.2s +tttg: c188/213 lr:0.000034 t:13.3s +tttg: c189/213 lr:0.000031 t:13.3s +tttg: c190/213 lr:0.000029 t:13.4s +tttg: c191/213 lr:0.000026 t:13.5s +tttg: c192/213 lr:0.000024 t:13.6s +tttg: c193/213 lr:0.000022 t:13.6s +tttg: c194/213 lr:0.000020 t:13.7s +tttg: c195/213 lr:0.000018 t:13.8s +tttg: c196/213 lr:0.000016 t:13.8s +tttg: c197/213 lr:0.000014 t:13.9s +tttg: c198/213 lr:0.000012 t:14.0s +tttg: c199/213 lr:0.000011 t:14.0s +tttg: c200/213 lr:0.000009 t:14.1s +tttg: c201/213 lr:0.000008 t:14.2s +tttg: c202/213 lr:0.000007 t:14.2s +tttg: c203/213 lr:0.000005 t:14.3s +tttg: c204/213 lr:0.000004 t:14.4s +tttg: c205/213 lr:0.000004 t:14.4s +tttg: c206/213 lr:0.000003 t:14.5s +tttg: c207/213 lr:0.000002 t:14.6s +tttg: c208/213 lr:0.000001 t:14.6s +tttg: c209/213 lr:0.000001 t:14.7s +tttg: c210/213 lr:0.000000 t:14.8s +tttg: c211/213 lr:0.000000 t:14.9s +tttg: c212/213 lr:0.000000 t:14.9s +ttpr: phase:3/3 t:365.0s +ttp: b736/782 bl:2.6753 bb:1.0428 rl:2.6795 rb:1.0669 dl:2140-2165 gd:1 +ttp: b734/782 bl:2.7676 bb:1.0554 rl:2.6860 rb:1.0660 dl:2091-2115 gd:1 +ttp: b722/782 bl:2.7645 bb:1.0570 rl:2.6908 rb:1.0654 dl:1846-1861 gd:1 +ttp: b720/782 bl:2.8156 bb:1.0755 rl:2.6979 rb:1.0660 dl:1816-1832 gd:1 +ttp: b708/782 bl:2.7173 bb:1.0442 rl:2.6988 rb:1.0649 dl:1639-1649 gd:1 +ttp: b703/782 bl:2.9084 bb:1.1001 rl:2.7082 rb:1.0666 dl:1582-1594 gd:1 +ttp: b696/782 bl:2.8038 bb:1.0717 rl:2.7122 rb:1.0668 dl:1513-1522 gd:1 +ttp: b690/782 bl:2.8371 bb:1.0631 rl:2.7169 rb:1.0666 dl:1458-1467 gd:1 +ttp: b683/782 bl:2.7659 bb:1.0650 rl:2.7187 rb:1.0666 dl:1400-1406 gd:1 +ttp: b675/782 bl:2.8325 bb:1.0633 rl:2.7224 rb:1.0665 dl:1341-1347 gd:1 +ttp: b667/782 bl:2.8104 bb:1.1009 rl:2.7251 rb:1.0675 dl:1288-1295 gd:1 +ttp: b662/782 bl:2.8085 bb:1.0717 rl:2.7275 rb:1.0676 dl:1258-1263 gd:1 +ttp: b653/782 bl:2.7562 bb:1.0340 rl:2.7282 rb:1.0667 dl:1203-1209 gd:1 +ttp: b644/782 bl:2.7308 bb:1.0304 rl:2.7283 rb:1.0658 dl:1155-1160 gd:1 +ttp: b636/782 bl:2.7570 bb:1.0695 rl:2.7290 rb:1.0658 dl:1116-1120 gd:1 +ttp: b629/782 bl:2.7155 bb:1.0404 rl:2.7287 rb:1.0653 dl:1082-1086 gd:1 +ttp: b621/782 bl:2.8291 bb:1.0835 rl:2.7308 rb:1.0657 dl:1046-1050 gd:1 +ttp: b614/782 bl:2.7797 bb:1.0632 rl:2.7318 rb:1.0656 dl:1016-1020 gd:1 +ttp: b605/782 bl:2.7390 bb:1.0565 rl:2.7319 rb:1.0654 dl:978-982 gd:1 +ttp: b597/782 bl:2.7674 bb:1.0390 rl:2.7326 rb:1.0649 dl:947-950 gd:1 +ttp: b588/782 bl:2.7318 bb:1.0422 rl:2.7326 rb:1.0645 dl:917-921 gd:1 +ttp: b580/782 bl:2.7270 bb:1.0361 rl:2.7325 rb:1.0640 dl:891-894 gd:1 +ttp: b572/782 bl:2.9327 bb:1.1161 rl:2.7356 rb:1.0649 dl:865-868 gd:1 +ttp: b564/782 bl:2.8607 bb:1.1068 rl:2.7375 rb:1.0655 dl:840-843 gd:1 +ttp: b556/782 bl:2.8239 bb:1.0796 rl:2.7388 rb:1.0657 dl:815-818 gd:1 +ttp: b546/782 bl:2.8229 bb:1.0718 rl:2.7399 rb:1.0658 dl:788-790 gd:1 +ttp: b538/782 bl:2.6794 bb:1.0363 rl:2.7391 rb:1.0654 dl:767-769 gd:1 +ttp: b534/782 bl:2.8109 bb:1.0692 rl:2.7400 rb:1.0655 dl:757-759 gd:1 +ttp: b526/782 bl:2.7622 bb:1.0548 rl:2.7403 rb:1.0653 dl:737-739 gd:1 +ttp: b518/782 bl:2.7269 bb:1.0503 rl:2.7402 rb:1.0652 dl:717-720 gd:1 +ttp: b511/782 bl:2.7677 bb:1.0453 rl:2.7405 rb:1.0649 dl:700-703 gd:1 +ttp: b503/782 bl:2.8165 bb:1.0725 rl:2.7413 rb:1.0650 dl:683-685 gd:1 +ttp: b494/782 bl:2.7876 bb:1.0509 rl:2.7418 rb:1.0649 dl:661-664 gd:1 +ttp: b486/782 bl:2.7917 bb:1.0596 rl:2.7423 rb:1.0648 dl:645-646 gd:1 +ttp: b478/782 bl:2.7870 bb:1.0495 rl:2.7428 rb:1.0646 dl:628-630 gd:1 +ttp: b467/782 bl:2.7851 bb:1.0521 rl:2.7432 rb:1.0645 dl:606-608 gd:1 +ttp: b461/782 bl:2.7674 bb:1.0554 rl:2.7434 rb:1.0644 dl:595-597 gd:1 +ttp: b453/782 bl:2.7566 bb:1.0578 rl:2.7435 rb:1.0644 dl:580-582 gd:1 +ttp: b443/782 bl:2.7576 bb:1.0504 rl:2.7436 rb:1.0643 dl:562-564 gd:1 +ttp: b433/782 bl:2.7681 bb:1.0624 rl:2.7438 rb:1.0642 dl:544-545 gd:1 +ttp: b425/782 bl:2.7505 bb:1.0464 rl:2.7439 rb:1.0641 dl:530-532 gd:1 +ttp: b417/782 bl:2.8074 bb:1.0528 rl:2.7444 rb:1.0640 dl:516-517 gd:1 +ttp: b409/782 bl:2.7074 bb:1.0460 rl:2.7441 rb:1.0639 dl:503-505 gd:1 +ttp: b401/782 bl:2.7316 bb:1.0574 rl:2.7440 rb:1.0638 dl:490-492 gd:1 +ttp: b394/782 bl:2.8830 bb:1.1119 rl:2.7450 rb:1.0642 dl:479-481 gd:1 +ttp: b391/782 bl:2.8154 bb:1.0965 rl:2.7454 rb:1.0644 dl:475-476 gd:1 +ttp: b383/782 bl:2.8311 bb:1.0842 rl:2.7460 rb:1.0645 dl:463-464 gd:1 +ttp: b375/782 bl:2.8022 bb:1.1043 rl:2.7464 rb:1.0648 dl:452-453 gd:1 +ttp: b367/782 bl:2.8309 bb:1.0633 rl:2.7469 rb:1.0648 dl:440-441 gd:1 +ttp: b358/782 bl:2.8126 bb:1.0871 rl:2.7473 rb:1.0649 dl:427-429 gd:1 +ttp: b350/782 bl:2.7238 bb:1.0565 rl:2.7471 rb:1.0648 dl:417-418 gd:1 +ttp: b342/782 bl:2.8599 bb:1.1003 rl:2.7478 rb:1.0650 dl:406-407 gd:1 +ttp: b334/782 bl:2.8684 bb:1.1038 rl:2.7484 rb:1.0653 dl:395-396 gd:1 +ttp: b325/782 bl:2.8396 bb:1.0908 rl:2.7489 rb:1.0654 dl:384-385 gd:1 +ttp: b319/782 bl:2.8215 bb:1.1069 rl:2.7493 rb:1.0656 dl:376-377 gd:1 +ttp: b311/782 bl:2.8510 bb:1.0922 rl:2.7498 rb:1.0657 dl:365-367 gd:1 +ttp: b302/782 bl:2.8296 bb:1.0974 rl:2.7502 rb:1.0659 dl:354-355 gd:1 +ttp: b295/782 bl:2.8361 bb:1.1182 rl:2.7506 rb:1.0661 dl:345-347 gd:1 +ttp: b286/782 bl:2.8804 bb:1.0942 rl:2.7512 rb:1.0663 dl:335-336 gd:1 +ttp: b280/782 bl:2.8146 bb:1.0924 rl:2.7514 rb:1.0664 dl:329-329 gd:1 +ttp: b271/782 bl:2.7734 bb:1.0688 rl:2.7515 rb:1.0664 dl:319-320 gd:1 +ttp: b263/782 bl:2.8274 bb:1.1012 rl:2.7518 rb:1.0665 dl:310-311 gd:1 +ttp: b255/782 bl:2.8699 bb:1.1325 rl:2.7523 rb:1.0668 dl:300-301 gd:1 +ttp: b247/782 bl:2.7860 bb:1.0765 rl:2.7524 rb:1.0668 dl:292-293 gd:1 +ttp: b238/782 bl:2.8924 bb:1.1474 rl:2.7530 rb:1.0671 dl:283-284 gd:1 +ttp: b230/782 bl:2.9091 bb:1.1132 rl:2.7535 rb:1.0673 dl:275-276 gd:1 +ttp: b222/782 bl:2.8777 bb:1.1180 rl:2.7539 rb:1.0674 dl:267-268 gd:1 +ttp: b214/782 bl:2.9313 bb:1.1277 rl:2.7545 rb:1.0677 dl:259-260 gd:1 +ttp: b206/782 bl:2.8795 bb:1.1146 rl:2.7549 rb:1.0678 dl:252-253 gd:1 +ttp: b198/782 bl:2.9694 bb:1.1484 rl:2.7556 rb:1.0681 dl:245-246 gd:1 +ttp: b188/782 bl:2.9052 bb:1.1509 rl:2.7561 rb:1.0683 dl:236-237 gd:1 +ttp: b180/782 bl:2.9042 bb:1.1325 rl:2.7565 rb:1.0685 dl:229-230 gd:1 +ttp: b172/782 bl:3.0050 bb:1.1818 rl:2.7572 rb:1.0688 dl:222-223 gd:1 +ttp: b163/782 bl:2.8895 bb:1.1341 rl:2.7576 rb:1.0690 dl:214-215 gd:1 +ttp: b153/782 bl:3.0065 bb:1.1598 rl:2.7582 rb:1.0692 dl:206-207 gd:1 +ttp: b146/782 bl:2.8973 bb:1.1501 rl:2.7586 rb:1.0694 dl:200-201 gd:1 +ttp: b137/782 bl:2.9322 bb:1.1816 rl:2.7590 rb:1.0697 dl:193-194 gd:1 +ttp: b129/782 bl:2.9350 bb:1.1782 rl:2.7594 rb:1.0699 dl:187-187 gd:1 +ttp: b121/782 bl:2.8476 bb:1.1282 rl:2.7596 rb:1.0701 dl:181-181 gd:1 +ttp: b112/782 bl:2.9877 bb:1.1556 rl:2.7601 rb:1.0702 dl:174-175 gd:1 +ttp: b104/782 bl:2.9932 bb:1.1647 rl:2.7606 rb:1.0704 dl:168-169 gd:1 +ttp: b96/782 bl:2.9381 bb:1.1484 rl:2.7609 rb:1.0706 dl:162-163 gd:1 +ttp: b88/782 bl:3.0995 bb:1.2068 rl:2.7616 rb:1.0709 dl:156-157 gd:1 +ttp: b81/782 bl:2.9246 bb:1.1631 rl:2.7619 rb:1.0710 dl:151-151 gd:1 +ttp: b71/782 bl:2.9609 bb:1.1552 rl:2.7622 rb:1.0712 dl:143-144 gd:1 +ttp: b68/782 bl:3.1098 bb:1.2081 rl:2.7628 rb:1.0714 dl:141-142 gd:1 +ttp: b61/782 bl:2.9215 bb:1.1422 rl:2.7631 rb:1.0715 dl:135-136 gd:1 +ttp: b54/782 bl:3.1007 bb:1.2697 rl:2.7637 rb:1.0718 dl:130-130 gd:1 +ttp: b45/782 bl:3.0902 bb:1.2362 rl:2.7641 rb:1.0721 dl:122-123 gd:1 +ttp: b36/782 bl:2.9988 bb:1.2264 rl:2.7645 rb:1.0723 dl:115-116 gd:1 +ttp: b25/782 bl:3.2922 bb:1.3047 rl:2.7652 rb:1.0726 dl:106-107 gd:1 +ttp: b17/782 bl:3.1456 bb:1.2468 rl:2.7656 rb:1.0728 dl:98-99 gd:1 +ttp: b9/782 bl:3.2171 bb:1.2749 rl:2.7661 rb:1.0730 dl:87-89 gd:1 +ttp: b1/782 bl:3.3365 bb:1.2383 rl:2.7665 rb:1.0731 dl:45-70 gd:1 +quantized_ttt_phased val_loss:2.76880744 val_bpb:1.07189164 eval_time:456456ms +total_eval_time:456.5s +[Wed Apr 22 00:41:20 UTC 2026] Experiment pure_ttt_sp8192_dxa144_seed1337_8gpu complete +Final BPB: diagnostic pre-quantization post-ema val_loss:2.77352391 val_bpb:1.07368293 eval_time:6579ms +diagnostic quantized val_loss:2.80324509 val_bpb:1.08518856 eval_time:59228ms +quantized_ttt_phased val_loss:2.76880744 val_bpb:1.07189164 eval_time:456456ms diff --git a/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed314.log b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed314.log new file mode 100644 index 0000000000..3dcf598cbb --- /dev/null +++ b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed314.log @@ -0,0 +1,756 @@ +[Wed Apr 22 00:11:14 UTC 2026] Starting SP8192 dexhunter experiment: pure_ttt_sp8192_dxa144_seed314_8gpu +Extra args: SEED=314 PHASED_TTT_ENABLED=1 PHASED_TTT_PREFIX_DOCS=2000 PHASED_TTT_NUM_PHASES=3 MLP_CLIP_SIGMAS=12.0 ATTN_CLIP_SIGMAS=13.0 EMBED_BITS=7 EMBED_CLIP_SIGMAS=15.0 MATRIX_LR=0.026 GPTQ_RESERVE_SECONDS=4 GPTQ_CALIBRATION_BATCHES=16 TTT_LORA_RANK=128 TTT_LORA_ALPHA=144 TTT_WARM_START_A=1 TTT_WEIGHT_DECAY=1.0 +torch: 2.9.1+cu128 +FA3 interface: OK +[Wed Apr 22 00:11:20 UTC 2026] GPU check: +0, 1 MiB, 0 % +1, 1 MiB, 0 % +2, 1 MiB, 0 % +3, 1 MiB, 0 % +4, 1 MiB, 0 % +5, 1 MiB, 0 % +6, 1 MiB, 0 % +7, 1 MiB, 0 % +W0422 00:11:21.728000 112 torch/distributed/run.py:803] +W0422 00:11:21.728000 112 torch/distributed/run.py:803] ***************************************** +W0422 00:11:21.728000 112 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0422 00:11:21.728000 112 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: + attn_clip_sigmas: 13.0 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: DATA_DIR + datasets_dir: DATA_DIR/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 15.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/8b66f202-61ad-44b2-8133-e54dcaa5e869.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_clip_sigmas: 12.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_enabled: True + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: 8b66f202-61ad-44b2-8133-e54dcaa5e869 + scalar_lr: 0.02 + seed: 314 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: DATA_DIR/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: DATA_DIR/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 128 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 1.0 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: DATA_DIR/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +gptq:reserving 4s, effective=596000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0092 val_bpb: 3.4876 +1/20000 train_loss: 9.0084 train_time: 0.0m tok/s: 16601837 +2/20000 train_loss: 12.3017 train_time: 0.0m tok/s: 12882499 +3/20000 train_loss: 11.2467 train_time: 0.0m tok/s: 10974981 +4/20000 train_loss: 9.6647 train_time: 0.0m tok/s: 10132876 +5/20000 train_loss: 8.2432 train_time: 0.0m tok/s: 9746251 +500/20000 train_loss: 3.2606 train_time: 0.8m tok/s: 8302023 +1000/20000 train_loss: 3.0299 train_time: 1.6m tok/s: 8246192 +1500/20000 train_loss: 3.0339 train_time: 2.4m tok/s: 8234101 +2000/20000 train_loss: 2.9832 train_time: 3.2m tok/s: 8237881 +layer_loop:enabled step:2183 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0748 train_time: 4.2m tok/s: 7779406 +3000/20000 train_loss: 2.9130 train_time: 5.4m tok/s: 7317277 +3500/20000 train_loss: 2.9788 train_time: 6.5m tok/s: 7018208 +4000/20000 train_loss: 2.9152 train_time: 7.7m tok/s: 6810416 +4000/20000 val_loss: 2.8882 val_bpb: 1.1181 +4500/20000 train_loss: 2.8696 train_time: 8.9m tok/s: 6658972 +4940/20000 val_loss: 2.7705 val_bpb: 1.0725 +stopping_early: wallclock_cap train_time: 596094ms step: 4940/20000 +peak memory allocated: 40029 MiB reserved: 44036 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.76935172 val_bpb:1.07206780 eval_time:6637ms +Serialized model: 135409136 bytes +Code size (uncompressed): 122656 bytes +Code size (compressed): 27680 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 3.6s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15908137 bytes +Total submission size quantized+brotli: 15935817 bytes +diagnostic quantized val_loss:2.80162128 val_bpb:1.08455995 eval_time:57889ms +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (152.5s) + +beginning TTT eval timer +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 num_phases:3 boundaries:[666, 1333, 2000] +ttp: b777/782 bl:2.7268 bb:1.0907 rl:2.7268 rb:1.0907 dl:7190-7938 gd:0 +ttp: b772/782 bl:2.7664 bb:1.1065 rl:2.7427 rb:1.0970 dl:4937-5193 gd:0 +ttp: b767/782 bl:2.7524 bb:1.0989 rl:2.7451 rb:1.0975 dl:3963-4123 gd:0 +ttpp: phase:1/3 pd:1104 gd:666 t:200.3s +tttg: c1/95 lr:0.001000 t:1.7s +tttg: c2/95 lr:0.001000 t:1.7s +tttg: c3/95 lr:0.000999 t:1.8s +tttg: c4/95 lr:0.000997 t:1.9s +tttg: c5/95 lr:0.000996 t:1.9s +tttg: c6/95 lr:0.000993 t:2.0s +tttg: c7/95 lr:0.000990 t:2.1s +tttg: c8/95 lr:0.000986 t:2.2s +tttg: c9/95 lr:0.000982 t:2.2s +tttg: c10/95 lr:0.000978 t:2.3s +tttg: c11/95 lr:0.000972 t:2.4s +tttg: c12/95 lr:0.000967 t:2.4s +tttg: c13/95 lr:0.000960 t:2.5s +tttg: c14/95 lr:0.000954 t:2.6s +tttg: c15/95 lr:0.000946 t:2.6s +tttg: c16/95 lr:0.000938 t:2.7s +tttg: c17/95 lr:0.000930 t:2.8s +tttg: c18/95 lr:0.000921 t:2.9s +tttg: c19/95 lr:0.000912 t:2.9s +tttg: c20/95 lr:0.000903 t:3.0s +tttg: c21/95 lr:0.000892 t:3.1s +tttg: c22/95 lr:0.000882 t:3.1s +tttg: c23/95 lr:0.000871 t:3.2s +tttg: c24/95 lr:0.000859 t:3.3s +tttg: c25/95 lr:0.000848 t:3.3s +tttg: c26/95 lr:0.000835 t:3.4s +tttg: c27/95 lr:0.000823 t:3.5s +tttg: c28/95 lr:0.000810 t:3.5s +tttg: c29/95 lr:0.000797 t:3.6s +tttg: c30/95 lr:0.000783 t:3.7s +tttg: c31/95 lr:0.000769 t:3.8s +tttg: c32/95 lr:0.000755 t:3.8s +tttg: c33/95 lr:0.000740 t:3.9s +tttg: c34/95 lr:0.000726 t:4.0s +tttg: c35/95 lr:0.000710 t:4.0s +tttg: c36/95 lr:0.000695 t:4.1s +tttg: c37/95 lr:0.000680 t:4.2s +tttg: c38/95 lr:0.000664 t:4.2s +tttg: c39/95 lr:0.000648 t:4.3s +tttg: c40/95 lr:0.000632 t:4.4s +tttg: c41/95 lr:0.000616 t:4.4s +tttg: c42/95 lr:0.000600 t:4.5s +tttg: c43/95 lr:0.000583 t:4.6s +tttg: c44/95 lr:0.000567 t:4.6s +tttg: c45/95 lr:0.000550 t:4.7s +tttg: c46/95 lr:0.000533 t:4.8s +tttg: c47/95 lr:0.000517 t:4.9s +tttg: c48/95 lr:0.000500 t:4.9s +tttg: c49/95 lr:0.000483 t:5.0s +tttg: c50/95 lr:0.000467 t:5.1s +tttg: c51/95 lr:0.000450 t:5.1s +tttg: c52/95 lr:0.000433 t:5.2s +tttg: c53/95 lr:0.000417 t:5.3s +tttg: c54/95 lr:0.000400 t:5.3s +tttg: c55/95 lr:0.000384 t:5.4s +tttg: c56/95 lr:0.000368 t:5.5s +tttg: c57/95 lr:0.000352 t:5.5s +tttg: c58/95 lr:0.000336 t:5.6s +tttg: c59/95 lr:0.000320 t:5.7s +tttg: c60/95 lr:0.000305 t:5.8s +tttg: c61/95 lr:0.000290 t:5.8s +tttg: c62/95 lr:0.000274 t:5.9s +tttg: c63/95 lr:0.000260 t:6.0s +tttg: c64/95 lr:0.000245 t:6.0s +tttg: c65/95 lr:0.000231 t:6.1s +tttg: c66/95 lr:0.000217 t:6.2s +tttg: c67/95 lr:0.000203 t:6.2s +tttg: c68/95 lr:0.000190 t:6.3s +tttg: c69/95 lr:0.000177 t:6.4s +tttg: c70/95 lr:0.000165 t:6.4s +tttg: c71/95 lr:0.000152 t:6.5s +tttg: c72/95 lr:0.000141 t:6.6s +tttg: c73/95 lr:0.000129 t:6.6s +tttg: c74/95 lr:0.000118 t:6.7s +tttg: c75/95 lr:0.000108 t:6.8s +tttg: c76/95 lr:0.000097 t:6.9s +tttg: c77/95 lr:0.000088 t:6.9s +tttg: c78/95 lr:0.000079 t:7.0s +tttg: c79/95 lr:0.000070 t:7.1s +tttg: c80/95 lr:0.000062 t:7.1s +tttg: c81/95 lr:0.000054 t:7.2s +tttg: c82/95 lr:0.000046 t:7.3s +tttg: c83/95 lr:0.000040 t:7.4s +tttg: c84/95 lr:0.000033 t:7.4s +tttg: c85/95 lr:0.000028 t:7.5s +tttg: c86/95 lr:0.000022 t:7.6s +tttg: c87/95 lr:0.000018 t:7.6s +tttg: c88/95 lr:0.000014 t:7.7s +tttg: c89/95 lr:0.000010 t:7.8s +tttg: c90/95 lr:0.000007 t:7.8s +tttg: c91/95 lr:0.000004 t:7.9s +tttg: c92/95 lr:0.000003 t:8.0s +tttg: c93/95 lr:0.000001 t:8.1s +tttg: c94/95 lr:0.000000 t:8.1s +ttpr: phase:1/3 t:211.0s +ttp: b758/782 bl:2.8748 bb:1.0847 rl:2.7657 rb:1.0953 dl:3108-3187 gd:0 +ttp: b756/782 bl:2.7802 bb:1.0776 rl:2.7676 rb:1.0930 dl:2973-3032 gd:0 +ttpp: phase:2/3 pd:1808 gd:1333 t:319.9s +tttg: c1/158 lr:0.001000 t:0.1s +tttg: c2/158 lr:0.001000 t:0.1s +tttg: c3/158 lr:0.001000 t:0.2s +tttg: c4/158 lr:0.000999 t:0.3s +tttg: c5/158 lr:0.000998 t:0.4s +tttg: c6/158 lr:0.000997 t:0.4s +tttg: c7/158 lr:0.000996 t:0.5s +tttg: c8/158 lr:0.000995 t:0.6s +tttg: c9/158 lr:0.000994 t:0.7s +tttg: c10/158 lr:0.000992 t:0.7s +tttg: c11/158 lr:0.000990 t:0.8s +tttg: c12/158 lr:0.000988 t:0.9s +tttg: c13/158 lr:0.000986 t:0.9s +tttg: c14/158 lr:0.000983 t:1.0s +tttg: c15/158 lr:0.000981 t:1.1s +tttg: c16/158 lr:0.000978 t:1.2s +tttg: c17/158 lr:0.000975 t:1.2s +tttg: c18/158 lr:0.000971 t:1.3s +tttg: c19/158 lr:0.000968 t:1.4s +tttg: c20/158 lr:0.000964 t:1.5s +tttg: c21/158 lr:0.000960 t:1.5s +tttg: c22/158 lr:0.000957 t:1.6s +tttg: c23/158 lr:0.000952 t:1.7s +tttg: c24/158 lr:0.000948 t:1.7s +tttg: c25/158 lr:0.000943 t:1.8s +tttg: c26/158 lr:0.000939 t:1.9s +tttg: c27/158 lr:0.000934 t:2.0s +tttg: c28/158 lr:0.000929 t:2.0s +tttg: c29/158 lr:0.000924 t:2.1s +tttg: c30/158 lr:0.000918 t:2.2s +tttg: c31/158 lr:0.000913 t:2.2s +tttg: c32/158 lr:0.000907 t:2.3s +tttg: c33/158 lr:0.000901 t:2.4s +tttg: c34/158 lr:0.000895 t:2.4s +tttg: c35/158 lr:0.000889 t:2.5s +tttg: c36/158 lr:0.000882 t:2.6s +tttg: c37/158 lr:0.000876 t:2.7s +tttg: c38/158 lr:0.000869 t:2.7s +tttg: c39/158 lr:0.000862 t:2.8s +tttg: c40/158 lr:0.000855 t:2.9s +tttg: c41/158 lr:0.000848 t:2.9s +tttg: c42/158 lr:0.000841 t:3.0s +tttg: c43/158 lr:0.000834 t:3.1s +tttg: c44/158 lr:0.000826 t:3.2s +tttg: c45/158 lr:0.000818 t:3.2s +tttg: c46/158 lr:0.000811 t:3.3s +tttg: c47/158 lr:0.000803 t:3.4s +tttg: c48/158 lr:0.000795 t:3.4s +tttg: c49/158 lr:0.000787 t:3.5s +tttg: c50/158 lr:0.000778 t:3.6s +tttg: c51/158 lr:0.000770 t:3.7s +tttg: c52/158 lr:0.000761 t:3.7s +tttg: c53/158 lr:0.000753 t:3.8s +tttg: c54/158 lr:0.000744 t:3.9s +tttg: c55/158 lr:0.000735 t:4.0s +tttg: c56/158 lr:0.000727 t:4.0s +tttg: c57/158 lr:0.000718 t:4.1s +tttg: c58/158 lr:0.000709 t:4.2s +tttg: c59/158 lr:0.000699 t:4.3s +tttg: c60/158 lr:0.000690 t:4.3s +tttg: c61/158 lr:0.000681 t:4.4s +tttg: c62/158 lr:0.000672 t:4.5s +tttg: c63/158 lr:0.000662 t:4.6s +tttg: c64/158 lr:0.000653 t:4.6s +tttg: c65/158 lr:0.000643 t:4.7s +tttg: c66/158 lr:0.000633 t:4.8s +tttg: c67/158 lr:0.000624 t:4.8s +tttg: c68/158 lr:0.000614 t:4.9s +tttg: c69/158 lr:0.000604 t:5.0s +tttg: c70/158 lr:0.000594 t:5.1s +tttg: c71/158 lr:0.000585 t:5.1s +tttg: c72/158 lr:0.000575 t:5.2s +tttg: c73/158 lr:0.000565 t:5.3s +tttg: c74/158 lr:0.000555 t:5.4s +tttg: c75/158 lr:0.000545 t:5.4s +tttg: c76/158 lr:0.000535 t:5.5s +tttg: c77/158 lr:0.000525 t:5.6s +tttg: c78/158 lr:0.000515 t:5.6s +tttg: c79/158 lr:0.000505 t:5.7s +tttg: c80/158 lr:0.000495 t:5.8s +tttg: c81/158 lr:0.000485 t:5.8s +tttg: c82/158 lr:0.000475 t:5.9s +tttg: c83/158 lr:0.000465 t:6.0s +tttg: c84/158 lr:0.000455 t:6.1s +tttg: c85/158 lr:0.000445 t:6.1s +tttg: c86/158 lr:0.000435 t:6.2s +tttg: c87/158 lr:0.000425 t:6.3s +tttg: c88/158 lr:0.000415 t:6.3s +tttg: c89/158 lr:0.000406 t:6.4s +tttg: c90/158 lr:0.000396 t:6.5s +tttg: c91/158 lr:0.000386 t:6.6s +tttg: c92/158 lr:0.000376 t:6.6s +tttg: c93/158 lr:0.000367 t:6.7s +tttg: c94/158 lr:0.000357 t:6.8s +tttg: c95/158 lr:0.000347 t:6.8s +tttg: c96/158 lr:0.000338 t:6.9s +tttg: c97/158 lr:0.000328 t:7.0s +tttg: c98/158 lr:0.000319 t:7.1s +tttg: c99/158 lr:0.000310 t:7.1s +tttg: c100/158 lr:0.000301 t:7.2s +tttg: c101/158 lr:0.000291 t:7.3s +tttg: c102/158 lr:0.000282 t:7.3s +tttg: c103/158 lr:0.000273 t:7.4s +tttg: c104/158 lr:0.000265 t:7.5s +tttg: c105/158 lr:0.000256 t:7.6s +tttg: c106/158 lr:0.000247 t:7.6s +tttg: c107/158 lr:0.000239 t:7.7s +tttg: c108/158 lr:0.000230 t:7.8s +tttg: c109/158 lr:0.000222 t:7.8s +tttg: c110/158 lr:0.000213 t:7.9s +tttg: c111/158 lr:0.000205 t:8.0s +tttg: c112/158 lr:0.000197 t:8.1s +tttg: c113/158 lr:0.000189 t:8.1s +tttg: c114/158 lr:0.000182 t:8.2s +tttg: c115/158 lr:0.000174 t:8.3s +tttg: c116/158 lr:0.000166 t:8.3s +tttg: c117/158 lr:0.000159 t:8.4s +tttg: c118/158 lr:0.000152 t:8.5s +tttg: c119/158 lr:0.000145 t:8.6s +tttg: c120/158 lr:0.000138 t:8.7s +tttg: c121/158 lr:0.000131 t:8.7s +tttg: c122/158 lr:0.000124 t:8.8s +tttg: c123/158 lr:0.000118 t:8.9s +tttg: c124/158 lr:0.000111 t:8.9s +tttg: c125/158 lr:0.000105 t:9.0s +tttg: c126/158 lr:0.000099 t:9.1s +tttg: c127/158 lr:0.000093 t:9.2s +tttg: c128/158 lr:0.000087 t:9.2s +tttg: c129/158 lr:0.000082 t:9.3s +tttg: c130/158 lr:0.000076 t:9.4s +tttg: c131/158 lr:0.000071 t:9.4s +tttg: c132/158 lr:0.000066 t:9.5s +tttg: c133/158 lr:0.000061 t:9.6s +tttg: c134/158 lr:0.000057 t:9.6s +tttg: c135/158 lr:0.000052 t:9.7s +tttg: c136/158 lr:0.000048 t:9.8s +tttg: c137/158 lr:0.000043 t:9.9s +tttg: c138/158 lr:0.000040 t:9.9s +tttg: c139/158 lr:0.000036 t:10.0s +tttg: c140/158 lr:0.000032 t:10.1s +tttg: c141/158 lr:0.000029 t:10.1s +tttg: c142/158 lr:0.000025 t:10.2s +tttg: c143/158 lr:0.000022 t:10.3s +tttg: c144/158 lr:0.000019 t:10.4s +tttg: c145/158 lr:0.000017 t:10.4s +tttg: c146/158 lr:0.000014 t:10.5s +tttg: c147/158 lr:0.000012 t:10.6s +tttg: c148/158 lr:0.000010 t:10.6s +tttg: c149/158 lr:0.000008 t:10.7s +tttg: c150/158 lr:0.000006 t:10.8s +tttg: c151/158 lr:0.000005 t:10.9s +tttg: c152/158 lr:0.000004 t:11.0s +tttg: c153/158 lr:0.000003 t:11.0s +tttg: c154/158 lr:0.000002 t:11.1s +tttg: c155/158 lr:0.000001 t:11.2s +tttg: c156/158 lr:0.000000 t:11.3s +tttg: c157/158 lr:0.000000 t:11.3s +ttpr: phase:2/3 t:333.9s +ttp: b746/782 bl:2.6773 bb:1.0542 rl:2.7587 rb:1.0891 dl:2459-2501 gd:0 +ttp: b745/782 bl:2.7833 bb:1.0880 rl:2.7609 rb:1.0890 dl:2421-2458 gd:0 +ttpp: phase:3/3 pd:2448 gd:2000 t:347.8s +tttg: c1/213 lr:0.001000 t:0.1s +tttg: c2/213 lr:0.001000 t:0.1s +tttg: c3/213 lr:0.001000 t:0.2s +tttg: c4/213 lr:0.001000 t:0.3s +tttg: c5/213 lr:0.000999 t:0.4s +tttg: c6/213 lr:0.000999 t:0.4s +tttg: c7/213 lr:0.000998 t:0.5s +tttg: c8/213 lr:0.000997 t:0.6s +tttg: c9/213 lr:0.000996 t:0.7s +tttg: c10/213 lr:0.000996 t:0.7s +tttg: c11/213 lr:0.000995 t:0.8s +tttg: c12/213 lr:0.000993 t:0.9s +tttg: c13/213 lr:0.000992 t:0.9s +tttg: c14/213 lr:0.000991 t:1.0s +tttg: c15/213 lr:0.000989 t:1.1s +tttg: c16/213 lr:0.000988 t:1.2s +tttg: c17/213 lr:0.000986 t:1.2s +tttg: c18/213 lr:0.000984 t:1.3s +tttg: c19/213 lr:0.000982 t:1.4s +tttg: c20/213 lr:0.000980 t:1.4s +tttg: c21/213 lr:0.000978 t:1.5s +tttg: c22/213 lr:0.000976 t:1.6s +tttg: c23/213 lr:0.000974 t:1.7s +tttg: c24/213 lr:0.000971 t:1.7s +tttg: c25/213 lr:0.000969 t:1.8s +tttg: c26/213 lr:0.000966 t:1.9s +tttg: c27/213 lr:0.000963 t:1.9s +tttg: c28/213 lr:0.000961 t:2.0s +tttg: c29/213 lr:0.000958 t:2.1s +tttg: c30/213 lr:0.000955 t:2.2s +tttg: c31/213 lr:0.000951 t:2.2s +tttg: c32/213 lr:0.000948 t:2.3s +tttg: c33/213 lr:0.000945 t:2.4s +tttg: c34/213 lr:0.000941 t:2.4s +tttg: c35/213 lr:0.000938 t:2.5s +tttg: c36/213 lr:0.000934 t:2.6s +tttg: c37/213 lr:0.000931 t:2.7s +tttg: c38/213 lr:0.000927 t:2.7s +tttg: c39/213 lr:0.000923 t:2.8s +tttg: c40/213 lr:0.000919 t:2.9s +tttg: c41/213 lr:0.000915 t:3.0s +tttg: c42/213 lr:0.000911 t:3.0s +tttg: c43/213 lr:0.000906 t:3.1s +tttg: c44/213 lr:0.000902 t:3.2s +tttg: c45/213 lr:0.000897 t:3.2s +tttg: c46/213 lr:0.000893 t:3.3s +tttg: c47/213 lr:0.000888 t:3.4s +tttg: c48/213 lr:0.000884 t:3.5s +tttg: c49/213 lr:0.000879 t:3.5s +tttg: c50/213 lr:0.000874 t:3.6s +tttg: c51/213 lr:0.000869 t:3.7s +tttg: c52/213 lr:0.000864 t:3.7s +tttg: c53/213 lr:0.000859 t:3.8s +tttg: c54/213 lr:0.000854 t:3.9s +tttg: c55/213 lr:0.000848 t:4.0s +tttg: c56/213 lr:0.000843 t:4.0s +tttg: c57/213 lr:0.000837 t:4.1s +tttg: c58/213 lr:0.000832 t:4.2s +tttg: c59/213 lr:0.000826 t:4.2s +tttg: c60/213 lr:0.000821 t:4.3s +tttg: c61/213 lr:0.000815 t:4.4s +tttg: c62/213 lr:0.000809 t:4.4s +tttg: c63/213 lr:0.000803 t:4.5s +tttg: c64/213 lr:0.000797 t:4.6s +tttg: c65/213 lr:0.000791 t:4.7s +tttg: c66/213 lr:0.000785 t:4.7s +tttg: c67/213 lr:0.000779 t:4.8s +tttg: c68/213 lr:0.000773 t:4.9s +tttg: c69/213 lr:0.000767 t:4.9s +tttg: c70/213 lr:0.000761 t:5.0s +tttg: c71/213 lr:0.000754 t:5.1s +tttg: c72/213 lr:0.000748 t:5.2s +tttg: c73/213 lr:0.000741 t:5.2s +tttg: c74/213 lr:0.000735 t:5.3s +tttg: c75/213 lr:0.000728 t:5.4s +tttg: c76/213 lr:0.000722 t:5.4s +tttg: c77/213 lr:0.000715 t:5.5s +tttg: c78/213 lr:0.000708 t:5.6s +tttg: c79/213 lr:0.000702 t:5.7s +tttg: c80/213 lr:0.000695 t:5.7s +tttg: c81/213 lr:0.000688 t:5.8s +tttg: c82/213 lr:0.000681 t:5.9s +tttg: c83/213 lr:0.000674 t:5.9s +tttg: c84/213 lr:0.000667 t:6.0s +tttg: c85/213 lr:0.000660 t:6.1s +tttg: c86/213 lr:0.000653 t:6.1s +tttg: c87/213 lr:0.000646 t:6.2s +tttg: c88/213 lr:0.000639 t:6.3s +tttg: c89/213 lr:0.000632 t:6.4s +tttg: c90/213 lr:0.000625 t:6.4s +tttg: c91/213 lr:0.000617 t:6.5s +tttg: c92/213 lr:0.000610 t:6.6s +tttg: c93/213 lr:0.000603 t:6.6s +tttg: c94/213 lr:0.000596 t:6.7s +tttg: c95/213 lr:0.000588 t:6.8s +tttg: c96/213 lr:0.000581 t:6.8s +tttg: c97/213 lr:0.000574 t:6.9s +tttg: c98/213 lr:0.000566 t:7.0s +tttg: c99/213 lr:0.000559 t:7.1s +tttg: c100/213 lr:0.000552 t:7.1s +tttg: c101/213 lr:0.000544 t:7.2s +tttg: c102/213 lr:0.000537 t:7.3s +tttg: c103/213 lr:0.000530 t:7.3s +tttg: c104/213 lr:0.000522 t:7.4s +tttg: c105/213 lr:0.000515 t:7.5s +tttg: c106/213 lr:0.000507 t:7.6s +tttg: c107/213 lr:0.000500 t:7.6s +tttg: c108/213 lr:0.000493 t:7.7s +tttg: c109/213 lr:0.000485 t:7.8s +tttg: c110/213 lr:0.000478 t:7.8s +tttg: c111/213 lr:0.000470 t:7.9s +tttg: c112/213 lr:0.000463 t:8.0s +tttg: c113/213 lr:0.000456 t:8.1s +tttg: c114/213 lr:0.000448 t:8.1s +tttg: c115/213 lr:0.000441 t:8.2s +tttg: c116/213 lr:0.000434 t:8.3s +tttg: c117/213 lr:0.000426 t:8.3s +tttg: c118/213 lr:0.000419 t:8.4s +tttg: c119/213 lr:0.000412 t:8.5s +tttg: c120/213 lr:0.000404 t:8.6s +tttg: c121/213 lr:0.000397 t:8.6s +tttg: c122/213 lr:0.000390 t:8.7s +tttg: c123/213 lr:0.000383 t:8.8s +tttg: c124/213 lr:0.000375 t:8.9s +tttg: c125/213 lr:0.000368 t:8.9s +tttg: c126/213 lr:0.000361 t:9.0s +tttg: c127/213 lr:0.000354 t:9.1s +tttg: c128/213 lr:0.000347 t:9.1s +tttg: c129/213 lr:0.000340 t:9.2s +tttg: c130/213 lr:0.000333 t:9.3s +tttg: c131/213 lr:0.000326 t:9.4s +tttg: c132/213 lr:0.000319 t:9.4s +tttg: c133/213 lr:0.000312 t:9.5s +tttg: c134/213 lr:0.000305 t:9.6s +tttg: c135/213 lr:0.000298 t:9.6s +tttg: c136/213 lr:0.000292 t:9.7s +tttg: c137/213 lr:0.000285 t:9.8s +tttg: c138/213 lr:0.000278 t:9.9s +tttg: c139/213 lr:0.000272 t:9.9s +tttg: c140/213 lr:0.000265 t:10.0s +tttg: c141/213 lr:0.000259 t:10.1s +tttg: c142/213 lr:0.000252 t:10.1s +tttg: c143/213 lr:0.000246 t:10.2s +tttg: c144/213 lr:0.000239 t:10.3s +tttg: c145/213 lr:0.000233 t:10.4s +tttg: c146/213 lr:0.000227 t:10.4s +tttg: c147/213 lr:0.000221 t:10.5s +tttg: c148/213 lr:0.000215 t:10.6s +tttg: c149/213 lr:0.000209 t:10.6s +tttg: c150/213 lr:0.000203 t:10.7s +tttg: c151/213 lr:0.000197 t:10.8s +tttg: c152/213 lr:0.000191 t:10.9s +tttg: c153/213 lr:0.000185 t:10.9s +tttg: c154/213 lr:0.000179 t:11.0s +tttg: c155/213 lr:0.000174 t:11.1s +tttg: c156/213 lr:0.000168 t:11.2s +tttg: c157/213 lr:0.000163 t:11.2s +tttg: c158/213 lr:0.000157 t:11.3s +tttg: c159/213 lr:0.000152 t:11.4s +tttg: c160/213 lr:0.000146 t:11.5s +tttg: c161/213 lr:0.000141 t:11.6s +tttg: c162/213 lr:0.000136 t:11.6s +tttg: c163/213 lr:0.000131 t:11.7s +tttg: c164/213 lr:0.000126 t:11.8s +tttg: c165/213 lr:0.000121 t:11.8s +tttg: c166/213 lr:0.000116 t:11.9s +tttg: c167/213 lr:0.000112 t:12.0s +tttg: c168/213 lr:0.000107 t:12.1s +tttg: c169/213 lr:0.000103 t:12.1s +tttg: c170/213 lr:0.000098 t:12.2s +tttg: c171/213 lr:0.000094 t:12.3s +tttg: c172/213 lr:0.000089 t:12.3s +tttg: c173/213 lr:0.000085 t:12.4s +tttg: c174/213 lr:0.000081 t:12.5s +tttg: c175/213 lr:0.000077 t:12.6s +tttg: c176/213 lr:0.000073 t:12.6s +tttg: c177/213 lr:0.000069 t:12.7s +tttg: c178/213 lr:0.000066 t:12.8s +tttg: c179/213 lr:0.000062 t:12.9s +tttg: c180/213 lr:0.000059 t:12.9s +tttg: c181/213 lr:0.000055 t:13.0s +tttg: c182/213 lr:0.000052 t:13.1s +tttg: c183/213 lr:0.000049 t:13.1s +tttg: c184/213 lr:0.000045 t:13.2s +tttg: c185/213 lr:0.000042 t:13.3s +tttg: c186/213 lr:0.000039 t:13.4s +tttg: c187/213 lr:0.000037 t:13.4s +tttg: c188/213 lr:0.000034 t:13.5s +tttg: c189/213 lr:0.000031 t:13.6s +tttg: c190/213 lr:0.000029 t:13.6s +tttg: c191/213 lr:0.000026 t:13.7s +tttg: c192/213 lr:0.000024 t:13.8s +tttg: c193/213 lr:0.000022 t:13.9s +tttg: c194/213 lr:0.000020 t:13.9s +tttg: c195/213 lr:0.000018 t:14.0s +tttg: c196/213 lr:0.000016 t:14.1s +tttg: c197/213 lr:0.000014 t:14.2s +tttg: c198/213 lr:0.000012 t:14.2s +tttg: c199/213 lr:0.000011 t:14.3s +tttg: c200/213 lr:0.000009 t:14.4s +tttg: c201/213 lr:0.000008 t:14.5s +tttg: c202/213 lr:0.000007 t:14.5s +tttg: c203/213 lr:0.000005 t:14.6s +tttg: c204/213 lr:0.000004 t:14.7s +tttg: c205/213 lr:0.000004 t:14.7s +tttg: c206/213 lr:0.000003 t:14.8s +tttg: c207/213 lr:0.000002 t:14.9s +tttg: c208/213 lr:0.000001 t:15.0s +tttg: c209/213 lr:0.000001 t:15.0s +tttg: c210/213 lr:0.000000 t:15.1s +tttg: c211/213 lr:0.000000 t:15.2s +tttg: c212/213 lr:0.000000 t:15.2s +ttpr: phase:3/3 t:365.6s +ttp: b736/782 bl:2.6753 bb:1.0428 rl:2.7547 rb:1.0857 dl:2140-2165 gd:1 +ttp: b735/782 bl:2.8310 bb:1.0781 rl:2.7598 rb:1.0852 dl:2116-2140 gd:1 +ttp: b722/782 bl:2.7645 bb:1.0570 rl:2.7601 rb:1.0836 dl:1846-1861 gd:1 +ttp: b719/782 bl:2.6756 bb:1.0260 rl:2.7558 rb:1.0806 dl:1793-1816 gd:1 +ttp: b705/782 bl:2.7803 bb:1.0709 rl:2.7568 rb:1.0802 dl:1606-1617 gd:1 +ttp: b702/782 bl:2.7954 bb:1.0633 rl:2.7584 rb:1.0795 dl:1572-1581 gd:1 +ttp: b694/782 bl:2.7596 bb:1.0651 rl:2.7584 rb:1.0789 dl:1494-1504 gd:1 +ttp: b681/782 bl:2.8133 bb:1.0681 rl:2.7603 rb:1.0786 dl:1383-1393 gd:1 +ttp: b679/782 bl:2.8498 bb:1.0857 rl:2.7631 rb:1.0788 dl:1368-1374 gd:1 +ttp: b668/782 bl:2.7964 bb:1.0599 rl:2.7641 rb:1.0782 dl:1295-1301 gd:1 +ttp: b656/782 bl:2.7432 bb:1.0356 rl:2.7635 rb:1.0770 dl:1220-1227 gd:1 +ttp: b649/782 bl:2.8005 bb:1.0562 rl:2.7645 rb:1.0765 dl:1183-1188 gd:1 +ttp: b644/782 bl:2.7308 bb:1.0304 rl:2.7637 rb:1.0753 dl:1155-1160 gd:1 +ttp: b633/782 bl:2.8193 bb:1.1000 rl:2.7649 rb:1.0759 dl:1101-1105 gd:1 +ttp: b630/782 bl:2.8226 bb:1.0570 rl:2.7662 rb:1.0755 dl:1087-1092 gd:1 +ttp: b622/782 bl:2.8382 bb:1.0740 rl:2.7676 rb:1.0754 dl:1050-1055 gd:1 +ttp: b615/782 bl:2.8311 bb:1.0630 rl:2.7689 rb:1.0752 dl:1020-1023 gd:1 +ttp: b605/782 bl:2.7388 bb:1.0565 rl:2.7683 rb:1.0748 dl:978-982 gd:1 +ttp: b597/782 bl:2.7680 bb:1.0393 rl:2.7683 rb:1.0742 dl:947-950 gd:1 +ttp: b589/782 bl:2.7453 bb:1.0509 rl:2.7679 rb:1.0738 dl:921-924 gd:1 +ttp: b578/782 bl:2.7942 bb:1.0646 rl:2.7683 rb:1.0737 dl:884-887 gd:1 +ttp: b539/782 bl:2.7195 bb:1.0412 rl:2.7677 rb:1.0732 dl:769-771 gd:1 +ttp: b533/782 bl:2.7623 bb:1.0318 rl:2.7676 rb:1.0726 dl:754-757 gd:1 +ttp: b527/782 bl:2.7312 bb:1.0378 rl:2.7672 rb:1.0722 dl:739-742 gd:1 +ttp: b520/782 bl:2.7857 bb:1.0557 rl:2.7674 rb:1.0720 dl:723-725 gd:1 +ttp: b513/782 bl:2.7277 bb:1.0096 rl:2.7669 rb:1.0712 dl:705-707 gd:1 +ttp: b505/782 bl:2.7691 bb:1.0580 rl:2.7669 rb:1.0711 dl:686-688 gd:1 +ttp: b497/782 bl:2.8345 bb:1.0808 rl:2.7677 rb:1.0712 dl:668-671 gd:1 +ttp: b489/782 bl:2.7911 bb:1.0792 rl:2.7679 rb:1.0713 dl:651-653 gd:1 +ttp: b481/782 bl:2.7928 bb:1.0980 rl:2.7682 rb:1.0715 dl:635-637 gd:1 +ttp: b473/782 bl:2.8314 bb:1.0773 rl:2.7688 rb:1.0716 dl:618-620 gd:1 +ttp: b465/782 bl:2.8056 bb:1.0583 rl:2.7691 rb:1.0715 dl:602-604 gd:1 +ttp: b457/782 bl:2.7556 bb:1.0463 rl:2.7690 rb:1.0712 dl:587-589 gd:1 +ttp: b449/782 bl:2.7947 bb:1.0519 rl:2.7692 rb:1.0710 dl:573-575 gd:1 +ttp: b441/782 bl:2.7079 bb:1.0424 rl:2.7687 rb:1.0708 dl:559-560 gd:1 +ttp: b433/782 bl:2.7688 bb:1.0627 rl:2.7687 rb:1.0707 dl:544-545 gd:1 +ttp: b425/782 bl:2.7502 bb:1.0463 rl:2.7686 rb:1.0705 dl:530-532 gd:1 +ttp: b413/782 bl:2.6408 bb:0.9957 rl:2.7676 rb:1.0699 dl:510-511 gd:1 +ttp: b405/782 bl:2.8201 bb:1.0659 rl:2.7680 rb:1.0699 dl:497-498 gd:1 +ttp: b396/782 bl:2.7567 bb:1.0549 rl:2.7679 rb:1.0698 dl:482-484 gd:1 +ttp: b387/782 bl:2.8322 bb:1.0721 rl:2.7684 rb:1.0698 dl:468-470 gd:1 +ttp: b380/782 bl:2.8361 bb:1.0739 rl:2.7688 rb:1.0699 dl:459-460 gd:1 +ttp: b371/782 bl:2.7978 bb:1.0694 rl:2.7690 rb:1.0698 dl:446-447 gd:1 +ttp: b363/782 bl:2.7365 bb:1.0912 rl:2.7688 rb:1.0700 dl:434-436 gd:1 +ttp: b355/782 bl:2.7023 bb:1.0648 rl:2.7684 rb:1.0699 dl:423-424 gd:1 +ttp: b346/782 bl:2.8384 bb:1.0832 rl:2.7688 rb:1.0700 dl:412-413 gd:1 +ttp: b338/782 bl:2.8479 bb:1.1108 rl:2.7692 rb:1.0703 dl:400-402 gd:1 +ttp: b330/782 bl:2.8601 bb:1.0904 rl:2.7697 rb:1.0704 dl:390-392 gd:1 +ttp: b322/782 bl:2.7493 bb:1.0743 rl:2.7696 rb:1.0704 dl:380-381 gd:1 +ttp: b314/782 bl:2.7943 bb:1.0621 rl:2.7698 rb:1.0703 dl:369-370 gd:1 +ttp: b306/782 bl:2.8764 bb:1.1380 rl:2.7703 rb:1.0707 dl:359-361 gd:1 +ttp: b298/782 bl:2.8387 bb:1.0988 rl:2.7706 rb:1.0708 dl:349-351 gd:1 +ttp: b290/782 bl:2.8577 bb:1.0827 rl:2.7710 rb:1.0709 dl:340-341 gd:1 +ttp: b282/782 bl:2.8096 bb:1.1188 rl:2.7712 rb:1.0711 dl:331-332 gd:1 +ttp: b274/782 bl:2.8057 bb:1.0893 rl:2.7713 rb:1.0711 dl:322-323 gd:1 +ttp: b266/782 bl:2.8404 bb:1.0923 rl:2.7716 rb:1.0712 dl:313-314 gd:1 +ttp: b258/782 bl:2.9483 bb:1.1626 rl:2.7724 rb:1.0716 dl:304-305 gd:1 +ttp: b250/782 bl:2.8636 bb:1.1379 rl:2.7727 rb:1.0719 dl:295-296 gd:1 +ttp: b242/782 bl:2.8942 bb:1.1065 rl:2.7732 rb:1.0720 dl:287-288 gd:1 +ttp: b239/782 bl:2.8872 bb:1.1323 rl:2.7736 rb:1.0722 dl:284-285 gd:1 +ttp: b231/782 bl:2.8147 bb:1.0978 rl:2.7738 rb:1.0723 dl:276-277 gd:1 +ttp: b223/782 bl:2.8259 bb:1.0879 rl:2.7739 rb:1.0724 dl:268-269 gd:1 +ttp: b215/782 bl:2.8464 bb:1.1421 rl:2.7742 rb:1.0726 dl:260-261 gd:1 +ttp: b208/782 bl:2.8235 bb:1.1148 rl:2.7744 rb:1.0727 dl:254-254 gd:1 +ttp: b198/782 bl:2.9696 bb:1.1485 rl:2.7750 rb:1.0730 dl:245-246 gd:1 +ttp: b190/782 bl:2.8805 bb:1.0950 rl:2.7753 rb:1.0730 dl:237-238 gd:1 +ttp: b182/782 bl:2.8483 bb:1.1331 rl:2.7755 rb:1.0732 dl:230-231 gd:1 +ttp: b173/782 bl:2.9637 bb:1.1521 rl:2.7761 rb:1.0734 dl:223-224 gd:1 +ttp: b167/782 bl:2.9598 bb:1.1831 rl:2.7766 rb:1.0737 dl:218-218 gd:1 +ttp: b157/782 bl:2.8216 bb:1.1121 rl:2.7767 rb:1.0738 dl:209-210 gd:1 +ttp: b148/782 bl:2.9736 bb:1.1557 rl:2.7772 rb:1.0741 dl:202-203 gd:1 +ttp: b138/782 bl:2.9154 bb:1.1605 rl:2.7776 rb:1.0743 dl:194-195 gd:1 +ttp: b132/782 bl:2.9419 bb:1.1322 rl:2.7780 rb:1.0744 dl:189-189 gd:1 +ttp: b123/782 bl:2.9439 bb:1.1762 rl:2.7783 rb:1.0746 dl:182-183 gd:1 +ttp: b115/782 bl:2.8615 bb:1.1547 rl:2.7785 rb:1.0748 dl:176-177 gd:1 +ttp: b107/782 bl:2.9248 bb:1.1481 rl:2.7788 rb:1.0750 dl:171-171 gd:1 +ttp: b99/782 bl:2.9770 bb:1.1837 rl:2.7793 rb:1.0752 dl:164-165 gd:1 +ttp: b89/782 bl:3.0022 bb:1.1973 rl:2.7797 rb:1.0754 dl:157-158 gd:1 +ttp: b82/782 bl:2.9801 bb:1.1996 rl:2.7801 rb:1.0756 dl:151-152 gd:1 +ttp: b74/782 bl:3.1294 bb:1.2791 rl:2.7807 rb:1.0760 dl:145-146 gd:1 +ttp: b66/782 bl:3.0976 bb:1.2688 rl:2.7813 rb:1.0763 dl:139-140 gd:1 +ttp: b59/782 bl:3.0511 bb:1.1918 rl:2.7817 rb:1.0765 dl:134-134 gd:1 +ttp: b50/782 bl:2.9708 bb:1.2195 rl:2.7820 rb:1.0767 dl:126-127 gd:1 +ttp: b42/782 bl:3.1147 bb:1.2469 rl:2.7825 rb:1.0770 dl:120-121 gd:1 +ttp: b33/782 bl:3.1028 bb:1.2147 rl:2.7830 rb:1.0772 dl:113-114 gd:1 +ttp: b26/782 bl:3.0867 bb:1.2584 rl:2.7834 rb:1.0774 dl:107-107 gd:1 +ttp: b16/782 bl:3.0547 bb:1.2179 rl:2.7837 rb:1.0775 dl:97-98 gd:1 +ttp: b5/782 bl:3.3140 bb:1.2927 rl:2.7842 rb:1.0778 dl:80-82 gd:1 +quantized_ttt_phased val_loss:2.76880346 val_bpb:1.07189010 eval_time:455666ms +total_eval_time:455.7s +[Wed Apr 22 00:41:27 UTC 2026] Experiment pure_ttt_sp8192_dxa144_seed314_8gpu complete +Final BPB: diagnostic pre-quantization post-ema val_loss:2.76935172 val_bpb:1.07206780 eval_time:6637ms +diagnostic quantized val_loss:2.80162128 val_bpb:1.08455995 eval_time:57889ms +quantized_ttt_phased val_loss:2.76880346 val_bpb:1.07189010 eval_time:455666ms diff --git a/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed42.log b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed42.log new file mode 100644 index 0000000000..de6fbe48b6 --- /dev/null +++ b/records/track_10min_16mb/2026-04-22_AlphaLoRA144_WarmStart_WD1_1.07209/train_seed42.log @@ -0,0 +1,756 @@ +[Tue Apr 21 23:38:26 UTC 2026] Starting SP8192 dexhunter experiment: pure_ttt_sp8192_dxa144_seed42_8gpu +Extra args: SEED=42 PHASED_TTT_ENABLED=1 PHASED_TTT_PREFIX_DOCS=2000 PHASED_TTT_NUM_PHASES=3 MLP_CLIP_SIGMAS=12.0 ATTN_CLIP_SIGMAS=13.0 EMBED_BITS=7 EMBED_CLIP_SIGMAS=15.0 MATRIX_LR=0.026 GPTQ_RESERVE_SECONDS=4 GPTQ_CALIBRATION_BATCHES=16 TTT_LORA_RANK=128 TTT_LORA_ALPHA=144 TTT_WARM_START_A=1 TTT_WEIGHT_DECAY=1.0 +torch: 2.9.1+cu128 +FA3 interface: OK +[Tue Apr 21 23:38:30 UTC 2026] GPU check: +0, 1 MiB, 0 % +1, 1 MiB, 0 % +2, 1 MiB, 0 % +3, 1 MiB, 0 % +4, 1 MiB, 0 % +5, 1 MiB, 0 % +6, 1 MiB, 0 % +7, 1 MiB, 0 % +W0421 23:38:31.760000 112 torch/distributed/run.py:803] +W0421 23:38:31.760000 112 torch/distributed/run.py:803] ***************************************** +W0421 23:38:31.760000 112 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0421 23:38:31.760000 112 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + artifact_dir: + attn_clip_sigmas: 13.0 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: DATA_DIR + datasets_dir: DATA_DIR/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 15.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 64 + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 4.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/8ab05557-959a-45dc-9fdd-71eb83b0cd2b.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_clip_sigmas: 12.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_enabled: True + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2000 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: 8ab05557-959a-45dc-9fdd-71eb83b0cd2b + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: DATA_DIR/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: DATA_DIR/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.999 + ttt_chunk_size: 48 + ttt_enabled: True + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_lora_lr: 0.0001 + ttt_lora_rank: 128 + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_weight_decay: 1.0 + val_batch_tokens: 524288 + val_doc_fraction: 1.0 + val_files: DATA_DIR/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.75 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 40540160 +model_params:35944602 +gptq:reserving 4s, effective=596000ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0078 val_bpb: 3.4871 +1/20000 train_loss: 9.0072 train_time: 0.0m tok/s: 16446920 +2/20000 train_loss: 12.3320 train_time: 0.0m tok/s: 12166569 +3/20000 train_loss: 11.2835 train_time: 0.0m tok/s: 10366412 +4/20000 train_loss: 9.6141 train_time: 0.0m tok/s: 9160410 +5/20000 train_loss: 8.2163 train_time: 0.0m tok/s: 8647332 +500/20000 train_loss: 3.2683 train_time: 0.8m tok/s: 8300234 +1000/20000 train_loss: 3.0319 train_time: 1.6m tok/s: 8255359 +1500/20000 train_loss: 3.0367 train_time: 2.4m tok/s: 8251095 +2000/20000 train_loss: 2.9875 train_time: 3.2m tok/s: 8251662 +layer_loop:enabled step:2188 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2500/20000 train_loss: 3.0772 train_time: 4.2m tok/s: 7802589 +3000/20000 train_loss: 2.9166 train_time: 5.4m tok/s: 7338963 +3500/20000 train_loss: 2.9853 train_time: 6.5m tok/s: 7041554 +4000/20000 train_loss: 2.9144 train_time: 7.7m tok/s: 6832572 +4000/20000 val_loss: 2.8919 val_bpb: 1.1195 +4500/20000 train_loss: 2.8683 train_time: 8.8m tok/s: 6678830 +4952/20000 val_loss: 2.7729 val_bpb: 1.0735 +stopping_early: wallclock_cap train_time: 596008ms step: 4952/20000 +peak memory allocated: 40029 MiB reserved: 44036 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.77158098 val_bpb:1.07293079 eval_time:6634ms +Serialized model: 135409136 bytes +Code size (uncompressed): 122656 bytes +Code size (compressed): 27680 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 6.3s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights +Serialized model quantized+brotli: 15902515 bytes +Total submission size quantized+brotli: 15930195 bytes +diagnostic quantized val_loss:2.80209025 val_bpb:1.08474150 eval_time:55029ms +ttt_lora:warming up compile (random tokens, no val data) +ttt_lora:compile warmup done (160.8s) + +beginning TTT eval timer +ttt_phased: total_docs:50000 prefix_docs:2000 suffix_docs:48000 num_phases:3 boundaries:[666, 1333, 2000] +ttp: b781/782 bl:2.5675 bb:1.0599 rl:2.5675 rb:1.0599 dl:14510-25988 gd:0 +ttpp: phase:1/3 pd:1104 gd:666 t:198.7s +tttg: c1/95 lr:0.001000 t:1.7s +tttg: c2/95 lr:0.001000 t:1.8s +tttg: c3/95 lr:0.000999 t:1.8s +tttg: c4/95 lr:0.000997 t:1.9s +tttg: c5/95 lr:0.000996 t:2.0s +tttg: c6/95 lr:0.000993 t:2.1s +tttg: c7/95 lr:0.000990 t:2.2s +tttg: c8/95 lr:0.000986 t:2.3s +tttg: c9/95 lr:0.000982 t:2.4s +tttg: c10/95 lr:0.000978 t:2.5s +tttg: c11/95 lr:0.000972 t:2.6s +tttg: c12/95 lr:0.000967 t:2.7s +tttg: c13/95 lr:0.000960 t:2.8s +tttg: c14/95 lr:0.000954 t:2.8s +tttg: c15/95 lr:0.000946 t:2.9s +tttg: c16/95 lr:0.000938 t:3.1s +tttg: c17/95 lr:0.000930 t:3.2s +tttg: c18/95 lr:0.000921 t:3.3s +tttg: c19/95 lr:0.000912 t:3.3s +tttg: c20/95 lr:0.000903 t:3.4s +tttg: c21/95 lr:0.000892 t:3.6s +tttg: c22/95 lr:0.000882 t:3.7s +tttg: c23/95 lr:0.000871 t:3.7s +tttg: c24/95 lr:0.000859 t:3.8s +tttg: c25/95 lr:0.000848 t:3.9s +tttg: c26/95 lr:0.000835 t:4.0s +tttg: c27/95 lr:0.000823 t:4.1s +tttg: c28/95 lr:0.000810 t:4.2s +tttg: c29/95 lr:0.000797 t:4.3s +tttg: c30/95 lr:0.000783 t:4.3s +tttg: c31/95 lr:0.000769 t:4.4s +tttg: c32/95 lr:0.000755 t:4.5s +tttg: c33/95 lr:0.000740 t:4.6s +tttg: c34/95 lr:0.000726 t:4.7s +tttg: c35/95 lr:0.000710 t:4.8s +tttg: c36/95 lr:0.000695 t:4.9s +tttg: c37/95 lr:0.000680 t:4.9s +tttg: c38/95 lr:0.000664 t:5.0s +tttg: c39/95 lr:0.000648 t:5.1s +tttg: c40/95 lr:0.000632 t:5.2s +tttg: c41/95 lr:0.000616 t:5.3s +tttg: c42/95 lr:0.000600 t:5.4s +tttg: c43/95 lr:0.000583 t:5.5s +tttg: c44/95 lr:0.000567 t:5.5s +tttg: c45/95 lr:0.000550 t:5.6s +tttg: c46/95 lr:0.000533 t:5.7s +tttg: c47/95 lr:0.000517 t:5.8s +tttg: c48/95 lr:0.000500 t:5.9s +tttg: c49/95 lr:0.000483 t:6.0s +tttg: c50/95 lr:0.000467 t:6.1s +tttg: c51/95 lr:0.000450 t:6.2s +tttg: c52/95 lr:0.000433 t:6.3s +tttg: c53/95 lr:0.000417 t:6.4s +tttg: c54/95 lr:0.000400 t:6.5s +tttg: c55/95 lr:0.000384 t:6.5s +tttg: c56/95 lr:0.000368 t:6.6s +tttg: c57/95 lr:0.000352 t:6.7s +tttg: c58/95 lr:0.000336 t:6.8s +tttg: c59/95 lr:0.000320 t:6.9s +tttg: c60/95 lr:0.000305 t:7.0s +tttg: c61/95 lr:0.000290 t:7.1s +tttg: c62/95 lr:0.000274 t:7.2s +tttg: c63/95 lr:0.000260 t:7.2s +tttg: c64/95 lr:0.000245 t:7.3s +tttg: c65/95 lr:0.000231 t:7.4s +tttg: c66/95 lr:0.000217 t:7.5s +tttg: c67/95 lr:0.000203 t:7.6s +tttg: c68/95 lr:0.000190 t:7.7s +tttg: c69/95 lr:0.000177 t:7.8s +tttg: c70/95 lr:0.000165 t:7.8s +tttg: c71/95 lr:0.000152 t:7.9s +tttg: c72/95 lr:0.000141 t:8.1s +tttg: c73/95 lr:0.000129 t:8.1s +tttg: c74/95 lr:0.000118 t:8.2s +tttg: c75/95 lr:0.000108 t:8.3s +tttg: c76/95 lr:0.000097 t:8.4s +tttg: c77/95 lr:0.000088 t:8.5s +tttg: c78/95 lr:0.000079 t:8.6s +tttg: c79/95 lr:0.000070 t:8.7s +tttg: c80/95 lr:0.000062 t:8.8s +tttg: c81/95 lr:0.000054 t:8.8s +tttg: c82/95 lr:0.000046 t:8.9s +tttg: c83/95 lr:0.000040 t:9.0s +tttg: c84/95 lr:0.000033 t:9.1s +tttg: c85/95 lr:0.000028 t:9.2s +tttg: c86/95 lr:0.000022 t:9.3s +tttg: c87/95 lr:0.000018 t:9.4s +tttg: c88/95 lr:0.000014 t:9.4s +tttg: c89/95 lr:0.000010 t:9.5s +tttg: c90/95 lr:0.000007 t:9.6s +tttg: c91/95 lr:0.000004 t:9.7s +tttg: c92/95 lr:0.000003 t:9.8s +tttg: c93/95 lr:0.000001 t:9.9s +tttg: c94/95 lr:0.000000 t:10.0s +ttpr: phase:1/3 t:211.3s +ttp: b757/782 bl:2.6422 bb:1.0212 rl:2.5780 rb:1.0541 dl:3033-3108 gd:0 +ttpp: phase:2/3 pd:1808 gd:1333 t:320.5s +tttg: c1/158 lr:0.001000 t:0.1s +tttg: c2/158 lr:0.001000 t:0.2s +tttg: c3/158 lr:0.001000 t:0.3s +tttg: c4/158 lr:0.000999 t:0.4s +tttg: c5/158 lr:0.000998 t:0.5s +tttg: c6/158 lr:0.000997 t:0.6s +tttg: c7/158 lr:0.000996 t:0.6s +tttg: c8/158 lr:0.000995 t:0.7s +tttg: c9/158 lr:0.000994 t:0.8s +tttg: c10/158 lr:0.000992 t:0.9s +tttg: c11/158 lr:0.000990 t:1.0s +tttg: c12/158 lr:0.000988 t:1.1s +tttg: c13/158 lr:0.000986 t:1.2s +tttg: c14/158 lr:0.000983 t:1.3s +tttg: c15/158 lr:0.000981 t:1.4s +tttg: c16/158 lr:0.000978 t:1.5s +tttg: c17/158 lr:0.000975 t:1.6s +tttg: c18/158 lr:0.000971 t:1.7s +tttg: c19/158 lr:0.000968 t:1.7s +tttg: c20/158 lr:0.000964 t:1.8s +tttg: c21/158 lr:0.000960 t:1.9s +tttg: c22/158 lr:0.000957 t:2.0s +tttg: c23/158 lr:0.000952 t:2.1s +tttg: c24/158 lr:0.000948 t:2.2s +tttg: c25/158 lr:0.000943 t:2.3s +tttg: c26/158 lr:0.000939 t:2.3s +tttg: c27/158 lr:0.000934 t:2.4s +tttg: c28/158 lr:0.000929 t:2.5s +tttg: c29/158 lr:0.000924 t:2.6s +tttg: c30/158 lr:0.000918 t:2.7s +tttg: c31/158 lr:0.000913 t:2.7s +tttg: c32/158 lr:0.000907 t:2.9s +tttg: c33/158 lr:0.000901 t:3.0s +tttg: c34/158 lr:0.000895 t:3.0s +tttg: c35/158 lr:0.000889 t:3.1s +tttg: c36/158 lr:0.000882 t:3.2s +tttg: c37/158 lr:0.000876 t:3.3s +tttg: c38/158 lr:0.000869 t:3.4s +tttg: c39/158 lr:0.000862 t:3.5s +tttg: c40/158 lr:0.000855 t:3.6s +tttg: c41/158 lr:0.000848 t:3.6s +tttg: c42/158 lr:0.000841 t:3.7s +tttg: c43/158 lr:0.000834 t:3.8s +tttg: c44/158 lr:0.000826 t:3.9s +tttg: c45/158 lr:0.000818 t:4.0s +tttg: c46/158 lr:0.000811 t:4.1s +tttg: c47/158 lr:0.000803 t:4.2s +tttg: c48/158 lr:0.000795 t:4.3s +tttg: c49/158 lr:0.000787 t:4.4s +tttg: c50/158 lr:0.000778 t:4.5s +tttg: c51/158 lr:0.000770 t:4.5s +tttg: c52/158 lr:0.000761 t:4.6s +tttg: c53/158 lr:0.000753 t:4.7s +tttg: c54/158 lr:0.000744 t:4.8s +tttg: c55/158 lr:0.000735 t:4.9s +tttg: c56/158 lr:0.000727 t:5.0s +tttg: c57/158 lr:0.000718 t:5.1s +tttg: c58/158 lr:0.000709 t:5.1s +tttg: c59/158 lr:0.000699 t:5.2s +tttg: c60/158 lr:0.000690 t:5.3s +tttg: c61/158 lr:0.000681 t:5.4s +tttg: c62/158 lr:0.000672 t:5.5s +tttg: c63/158 lr:0.000662 t:5.5s +tttg: c64/158 lr:0.000653 t:5.7s +tttg: c65/158 lr:0.000643 t:5.8s +tttg: c66/158 lr:0.000633 t:5.8s +tttg: c67/158 lr:0.000624 t:5.9s +tttg: c68/158 lr:0.000614 t:6.0s +tttg: c69/158 lr:0.000604 t:6.1s +tttg: c70/158 lr:0.000594 t:6.1s +tttg: c71/158 lr:0.000585 t:6.2s +tttg: c72/158 lr:0.000575 t:6.3s +tttg: c73/158 lr:0.000565 t:6.4s +tttg: c74/158 lr:0.000555 t:6.5s +tttg: c75/158 lr:0.000545 t:6.5s +tttg: c76/158 lr:0.000535 t:6.6s +tttg: c77/158 lr:0.000525 t:6.7s +tttg: c78/158 lr:0.000515 t:6.8s +tttg: c79/158 lr:0.000505 t:6.9s +tttg: c80/158 lr:0.000495 t:7.0s +tttg: c81/158 lr:0.000485 t:7.1s +tttg: c82/158 lr:0.000475 t:7.1s +tttg: c83/158 lr:0.000465 t:7.2s +tttg: c84/158 lr:0.000455 t:7.3s +tttg: c85/158 lr:0.000445 t:7.4s +tttg: c86/158 lr:0.000435 t:7.4s +tttg: c87/158 lr:0.000425 t:7.5s +tttg: c88/158 lr:0.000415 t:7.6s +tttg: c89/158 lr:0.000406 t:7.7s +tttg: c90/158 lr:0.000396 t:7.7s +tttg: c91/158 lr:0.000386 t:7.8s +tttg: c92/158 lr:0.000376 t:7.9s +tttg: c93/158 lr:0.000367 t:8.0s +tttg: c94/158 lr:0.000357 t:8.1s +tttg: c95/158 lr:0.000347 t:8.1s +tttg: c96/158 lr:0.000338 t:8.2s +tttg: c97/158 lr:0.000328 t:8.3s +tttg: c98/158 lr:0.000319 t:8.4s +tttg: c99/158 lr:0.000310 t:8.5s +tttg: c100/158 lr:0.000301 t:8.6s +tttg: c101/158 lr:0.000291 t:8.7s +tttg: c102/158 lr:0.000282 t:8.7s +tttg: c103/158 lr:0.000273 t:8.8s +tttg: c104/158 lr:0.000265 t:8.9s +tttg: c105/158 lr:0.000256 t:9.0s +tttg: c106/158 lr:0.000247 t:9.1s +tttg: c107/158 lr:0.000239 t:9.2s +tttg: c108/158 lr:0.000230 t:9.3s +tttg: c109/158 lr:0.000222 t:9.3s +tttg: c110/158 lr:0.000213 t:9.4s +tttg: c111/158 lr:0.000205 t:9.5s +tttg: c112/158 lr:0.000197 t:9.6s +tttg: c113/158 lr:0.000189 t:9.7s +tttg: c114/158 lr:0.000182 t:9.8s +tttg: c115/158 lr:0.000174 t:9.9s +tttg: c116/158 lr:0.000166 t:10.0s +tttg: c117/158 lr:0.000159 t:10.0s +tttg: c118/158 lr:0.000152 t:10.1s +tttg: c119/158 lr:0.000145 t:10.2s +tttg: c120/158 lr:0.000138 t:10.3s +tttg: c121/158 lr:0.000131 t:10.4s +tttg: c122/158 lr:0.000124 t:10.5s +tttg: c123/158 lr:0.000118 t:10.6s +tttg: c124/158 lr:0.000111 t:10.6s +tttg: c125/158 lr:0.000105 t:10.7s +tttg: c126/158 lr:0.000099 t:10.8s +tttg: c127/158 lr:0.000093 t:10.9s +tttg: c128/158 lr:0.000087 t:11.0s +tttg: c129/158 lr:0.000082 t:11.1s +tttg: c130/158 lr:0.000076 t:11.2s +tttg: c131/158 lr:0.000071 t:11.3s +tttg: c132/158 lr:0.000066 t:11.3s +tttg: c133/158 lr:0.000061 t:11.4s +tttg: c134/158 lr:0.000057 t:11.5s +tttg: c135/158 lr:0.000052 t:11.6s +tttg: c136/158 lr:0.000048 t:11.7s +tttg: c137/158 lr:0.000043 t:11.8s +tttg: c138/158 lr:0.000040 t:11.9s +tttg: c139/158 lr:0.000036 t:11.9s +tttg: c140/158 lr:0.000032 t:12.0s +tttg: c141/158 lr:0.000029 t:12.1s +tttg: c142/158 lr:0.000025 t:12.2s +tttg: c143/158 lr:0.000022 t:12.2s +tttg: c144/158 lr:0.000019 t:12.3s +tttg: c145/158 lr:0.000017 t:12.4s +tttg: c146/158 lr:0.000014 t:12.5s +tttg: c147/158 lr:0.000012 t:12.6s +tttg: c148/158 lr:0.000010 t:12.7s +tttg: c149/158 lr:0.000008 t:12.8s +tttg: c150/158 lr:0.000006 t:12.8s +tttg: c151/158 lr:0.000005 t:12.9s +tttg: c152/158 lr:0.000004 t:13.0s +tttg: c153/158 lr:0.000003 t:13.1s +tttg: c154/158 lr:0.000002 t:13.2s +tttg: c155/158 lr:0.000001 t:13.3s +tttg: c156/158 lr:0.000000 t:13.4s +tttg: c157/158 lr:0.000000 t:13.4s +ttpr: phase:2/3 t:336.5s +ttp: b746/782 bl:2.6793 bb:1.0549 rl:2.5883 rb:1.0542 dl:2459-2501 gd:0 +ttp: b744/782 bl:2.6548 bb:1.0576 rl:2.5943 rb:1.0545 dl:2388-2419 gd:0 +ttpp: phase:3/3 pd:2448 gd:2000 t:350.3s +tttg: c1/213 lr:0.001000 t:0.1s +tttg: c2/213 lr:0.001000 t:0.2s +tttg: c3/213 lr:0.001000 t:0.3s +tttg: c4/213 lr:0.001000 t:0.4s +tttg: c5/213 lr:0.000999 t:0.5s +tttg: c6/213 lr:0.000999 t:0.5s +tttg: c7/213 lr:0.000998 t:0.6s +tttg: c8/213 lr:0.000997 t:0.7s +tttg: c9/213 lr:0.000996 t:0.8s +tttg: c10/213 lr:0.000996 t:0.9s +tttg: c11/213 lr:0.000995 t:1.0s +tttg: c12/213 lr:0.000993 t:1.1s +tttg: c13/213 lr:0.000992 t:1.2s +tttg: c14/213 lr:0.000991 t:1.3s +tttg: c15/213 lr:0.000989 t:1.3s +tttg: c16/213 lr:0.000988 t:1.4s +tttg: c17/213 lr:0.000986 t:1.5s +tttg: c18/213 lr:0.000984 t:1.6s +tttg: c19/213 lr:0.000982 t:1.7s +tttg: c20/213 lr:0.000980 t:1.8s +tttg: c21/213 lr:0.000978 t:1.9s +tttg: c22/213 lr:0.000976 t:2.0s +tttg: c23/213 lr:0.000974 t:2.0s +tttg: c24/213 lr:0.000971 t:2.1s +tttg: c25/213 lr:0.000969 t:2.2s +tttg: c26/213 lr:0.000966 t:2.3s +tttg: c27/213 lr:0.000963 t:2.4s +tttg: c28/213 lr:0.000961 t:2.5s +tttg: c29/213 lr:0.000958 t:2.6s +tttg: c30/213 lr:0.000955 t:2.7s +tttg: c31/213 lr:0.000951 t:2.7s +tttg: c32/213 lr:0.000948 t:2.8s +tttg: c33/213 lr:0.000945 t:2.9s +tttg: c34/213 lr:0.000941 t:3.0s +tttg: c35/213 lr:0.000938 t:3.1s +tttg: c36/213 lr:0.000934 t:3.2s +tttg: c37/213 lr:0.000931 t:3.3s +tttg: c38/213 lr:0.000927 t:3.4s +tttg: c39/213 lr:0.000923 t:3.5s +tttg: c40/213 lr:0.000919 t:3.6s +tttg: c41/213 lr:0.000915 t:3.7s +tttg: c42/213 lr:0.000911 t:3.8s +tttg: c43/213 lr:0.000906 t:3.9s +tttg: c44/213 lr:0.000902 t:4.0s +tttg: c45/213 lr:0.000897 t:4.1s +tttg: c46/213 lr:0.000893 t:4.1s +tttg: c47/213 lr:0.000888 t:4.2s +tttg: c48/213 lr:0.000884 t:4.3s +tttg: c49/213 lr:0.000879 t:4.4s +tttg: c50/213 lr:0.000874 t:4.5s +tttg: c51/213 lr:0.000869 t:4.6s +tttg: c52/213 lr:0.000864 t:4.7s +tttg: c53/213 lr:0.000859 t:4.8s +tttg: c54/213 lr:0.000854 t:4.8s +tttg: c55/213 lr:0.000848 t:4.9s +tttg: c56/213 lr:0.000843 t:5.0s +tttg: c57/213 lr:0.000837 t:5.1s +tttg: c58/213 lr:0.000832 t:5.2s +tttg: c59/213 lr:0.000826 t:5.3s +tttg: c60/213 lr:0.000821 t:5.4s +tttg: c61/213 lr:0.000815 t:5.5s +tttg: c62/213 lr:0.000809 t:5.6s +tttg: c63/213 lr:0.000803 t:5.7s +tttg: c64/213 lr:0.000797 t:5.8s +tttg: c65/213 lr:0.000791 t:5.9s +tttg: c66/213 lr:0.000785 t:5.9s +tttg: c67/213 lr:0.000779 t:6.0s +tttg: c68/213 lr:0.000773 t:6.1s +tttg: c69/213 lr:0.000767 t:6.2s +tttg: c70/213 lr:0.000761 t:6.3s +tttg: c71/213 lr:0.000754 t:6.4s +tttg: c72/213 lr:0.000748 t:6.5s +tttg: c73/213 lr:0.000741 t:6.5s +tttg: c74/213 lr:0.000735 t:6.6s +tttg: c75/213 lr:0.000728 t:6.7s +tttg: c76/213 lr:0.000722 t:6.8s +tttg: c77/213 lr:0.000715 t:6.9s +tttg: c78/213 lr:0.000708 t:6.9s +tttg: c79/213 lr:0.000702 t:7.0s +tttg: c80/213 lr:0.000695 t:7.1s +tttg: c81/213 lr:0.000688 t:7.2s +tttg: c82/213 lr:0.000681 t:7.3s +tttg: c83/213 lr:0.000674 t:7.4s +tttg: c84/213 lr:0.000667 t:7.5s +tttg: c85/213 lr:0.000660 t:7.6s +tttg: c86/213 lr:0.000653 t:7.6s +tttg: c87/213 lr:0.000646 t:7.7s +tttg: c88/213 lr:0.000639 t:7.8s +tttg: c89/213 lr:0.000632 t:7.9s +tttg: c90/213 lr:0.000625 t:8.0s +tttg: c91/213 lr:0.000617 t:8.1s +tttg: c92/213 lr:0.000610 t:8.2s +tttg: c93/213 lr:0.000603 t:8.3s +tttg: c94/213 lr:0.000596 t:8.4s +tttg: c95/213 lr:0.000588 t:8.5s +tttg: c96/213 lr:0.000581 t:8.6s +tttg: c97/213 lr:0.000574 t:8.7s +tttg: c98/213 lr:0.000566 t:8.8s +tttg: c99/213 lr:0.000559 t:8.8s +tttg: c100/213 lr:0.000552 t:8.9s +tttg: c101/213 lr:0.000544 t:9.0s +tttg: c102/213 lr:0.000537 t:9.1s +tttg: c103/213 lr:0.000530 t:9.2s +tttg: c104/213 lr:0.000522 t:9.2s +tttg: c105/213 lr:0.000515 t:9.3s +tttg: c106/213 lr:0.000507 t:9.4s +tttg: c107/213 lr:0.000500 t:9.5s +tttg: c108/213 lr:0.000493 t:9.6s +tttg: c109/213 lr:0.000485 t:9.7s +tttg: c110/213 lr:0.000478 t:9.8s +tttg: c111/213 lr:0.000470 t:9.9s +tttg: c112/213 lr:0.000463 t:10.0s +tttg: c113/213 lr:0.000456 t:10.1s +tttg: c114/213 lr:0.000448 t:10.1s +tttg: c115/213 lr:0.000441 t:10.2s +tttg: c116/213 lr:0.000434 t:10.3s +tttg: c117/213 lr:0.000426 t:10.4s +tttg: c118/213 lr:0.000419 t:10.4s +tttg: c119/213 lr:0.000412 t:10.5s +tttg: c120/213 lr:0.000404 t:10.6s +tttg: c121/213 lr:0.000397 t:10.7s +tttg: c122/213 lr:0.000390 t:10.8s +tttg: c123/213 lr:0.000383 t:10.9s +tttg: c124/213 lr:0.000375 t:11.0s +tttg: c125/213 lr:0.000368 t:11.1s +tttg: c126/213 lr:0.000361 t:11.2s +tttg: c127/213 lr:0.000354 t:11.3s +tttg: c128/213 lr:0.000347 t:11.3s +tttg: c129/213 lr:0.000340 t:11.4s +tttg: c130/213 lr:0.000333 t:11.5s +tttg: c131/213 lr:0.000326 t:11.6s +tttg: c132/213 lr:0.000319 t:11.7s +tttg: c133/213 lr:0.000312 t:11.8s +tttg: c134/213 lr:0.000305 t:11.9s +tttg: c135/213 lr:0.000298 t:11.9s +tttg: c136/213 lr:0.000292 t:12.0s +tttg: c137/213 lr:0.000285 t:12.1s +tttg: c138/213 lr:0.000278 t:12.2s +tttg: c139/213 lr:0.000272 t:12.3s +tttg: c140/213 lr:0.000265 t:12.3s +tttg: c141/213 lr:0.000259 t:12.4s +tttg: c142/213 lr:0.000252 t:12.5s +tttg: c143/213 lr:0.000246 t:12.6s +tttg: c144/213 lr:0.000239 t:12.7s +tttg: c145/213 lr:0.000233 t:12.7s +tttg: c146/213 lr:0.000227 t:12.8s +tttg: c147/213 lr:0.000221 t:13.0s +tttg: c148/213 lr:0.000215 t:13.1s +tttg: c149/213 lr:0.000209 t:13.2s +tttg: c150/213 lr:0.000203 t:13.2s +tttg: c151/213 lr:0.000197 t:13.3s +tttg: c152/213 lr:0.000191 t:13.4s +tttg: c153/213 lr:0.000185 t:13.5s +tttg: c154/213 lr:0.000179 t:13.6s +tttg: c155/213 lr:0.000174 t:13.7s +tttg: c156/213 lr:0.000168 t:13.8s +tttg: c157/213 lr:0.000163 t:13.8s +tttg: c158/213 lr:0.000157 t:13.9s +tttg: c159/213 lr:0.000152 t:14.0s +tttg: c160/213 lr:0.000146 t:14.1s +tttg: c161/213 lr:0.000141 t:14.2s +tttg: c162/213 lr:0.000136 t:14.3s +tttg: c163/213 lr:0.000131 t:14.4s +tttg: c164/213 lr:0.000126 t:14.4s +tttg: c165/213 lr:0.000121 t:14.5s +tttg: c166/213 lr:0.000116 t:14.6s +tttg: c167/213 lr:0.000112 t:14.7s +tttg: c168/213 lr:0.000107 t:14.8s +tttg: c169/213 lr:0.000103 t:14.9s +tttg: c170/213 lr:0.000098 t:15.0s +tttg: c171/213 lr:0.000094 t:15.1s +tttg: c172/213 lr:0.000089 t:15.2s +tttg: c173/213 lr:0.000085 t:15.2s +tttg: c174/213 lr:0.000081 t:15.3s +tttg: c175/213 lr:0.000077 t:15.4s +tttg: c176/213 lr:0.000073 t:15.5s +tttg: c177/213 lr:0.000069 t:15.6s +tttg: c178/213 lr:0.000066 t:15.7s +tttg: c179/213 lr:0.000062 t:15.8s +tttg: c180/213 lr:0.000059 t:15.8s +tttg: c181/213 lr:0.000055 t:15.9s +tttg: c182/213 lr:0.000052 t:16.0s +tttg: c183/213 lr:0.000049 t:16.1s +tttg: c184/213 lr:0.000045 t:16.2s +tttg: c185/213 lr:0.000042 t:16.3s +tttg: c186/213 lr:0.000039 t:16.3s +tttg: c187/213 lr:0.000037 t:16.4s +tttg: c188/213 lr:0.000034 t:16.5s +tttg: c189/213 lr:0.000031 t:16.6s +tttg: c190/213 lr:0.000029 t:16.7s +tttg: c191/213 lr:0.000026 t:16.8s +tttg: c192/213 lr:0.000024 t:16.9s +tttg: c193/213 lr:0.000022 t:16.9s +tttg: c194/213 lr:0.000020 t:17.0s +tttg: c195/213 lr:0.000018 t:17.1s +tttg: c196/213 lr:0.000016 t:17.2s +tttg: c197/213 lr:0.000014 t:17.3s +tttg: c198/213 lr:0.000012 t:17.4s +tttg: c199/213 lr:0.000011 t:17.5s +tttg: c200/213 lr:0.000009 t:17.6s +tttg: c201/213 lr:0.000008 t:17.7s +tttg: c202/213 lr:0.000007 t:17.8s +tttg: c203/213 lr:0.000005 t:17.8s +tttg: c204/213 lr:0.000004 t:17.9s +tttg: c205/213 lr:0.000004 t:18.0s +tttg: c206/213 lr:0.000003 t:18.1s +tttg: c207/213 lr:0.000002 t:18.2s +tttg: c208/213 lr:0.000001 t:18.3s +tttg: c209/213 lr:0.000001 t:18.3s +tttg: c210/213 lr:0.000000 t:18.4s +tttg: c211/213 lr:0.000000 t:18.5s +tttg: c212/213 lr:0.000000 t:18.6s +ttpr: phase:3/3 t:371.4s +ttp: b736/782 bl:2.6732 bb:1.0420 rl:2.6001 rb:1.0536 dl:2140-2165 gd:1 +ttp: b734/782 bl:2.7709 bb:1.0566 rl:2.6117 rb:1.0538 dl:2091-2115 gd:1 +ttp: b722/782 bl:2.7684 bb:1.0585 rl:2.6206 rb:1.0541 dl:1846-1861 gd:1 +ttp: b720/782 bl:2.8188 bb:1.0767 rl:2.6310 rb:1.0553 dl:1816-1832 gd:1 +ttp: b706/782 bl:2.7125 bb:1.0427 rl:2.6346 rb:1.0547 dl:1617-1627 gd:1 +ttp: b698/782 bl:2.7781 bb:1.0297 rl:2.6405 rb:1.0536 dl:1534-1543 gd:1 +ttp: b692/782 bl:2.7632 bb:1.0484 rl:2.6451 rb:1.0534 dl:1477-1484 gd:1 +ttp: b688/782 bl:2.7450 bb:1.0472 rl:2.6486 rb:1.0532 dl:1441-1450 gd:1 +ttp: b674/782 bl:2.7846 bb:1.0566 rl:2.6530 rb:1.0533 dl:1334-1341 gd:1 +ttp: b671/782 bl:2.8777 bb:1.1153 rl:2.6598 rb:1.0552 dl:1316-1321 gd:1 +ttp: b659/782 bl:2.7156 bb:1.0226 rl:2.6613 rb:1.0543 dl:1239-1245 gd:1 +ttp: b656/782 bl:2.7456 bb:1.0365 rl:2.6636 rb:1.0538 dl:1220-1227 gd:1 +ttp: b644/782 bl:2.7321 bb:1.0309 rl:2.6653 rb:1.0532 dl:1155-1160 gd:1 +ttp: b634/782 bl:2.6957 bb:1.0407 rl:2.6660 rb:1.0529 dl:1105-1111 gd:1 +ttp: b625/782 bl:2.6652 bb:1.0013 rl:2.6659 rb:1.0517 dl:1064-1068 gd:1 +ttp: b617/782 bl:2.7362 bb:1.0356 rl:2.6674 rb:1.0514 dl:1027-1031 gd:1 +ttp: b609/782 bl:2.7814 bb:1.0558 rl:2.6696 rb:1.0515 dl:994-999 gd:1 +ttp: b601/782 bl:2.7572 bb:1.0596 rl:2.6712 rb:1.0516 dl:963-966 gd:1 +ttp: b594/782 bl:2.8946 bb:1.0993 rl:2.6752 rb:1.0525 dl:937-940 gd:1 +ttp: b586/782 bl:2.7172 bb:1.0110 rl:2.6759 rb:1.0518 dl:911-914 gd:1 +ttp: b577/782 bl:2.7545 bb:1.0419 rl:2.6771 rb:1.0516 dl:880-884 gd:1 +ttp: b571/782 bl:2.7038 bb:1.0314 rl:2.6775 rb:1.0513 dl:862-865 gd:1 +ttp: b568/782 bl:2.7962 bb:1.0547 rl:2.6793 rb:1.0513 dl:852-855 gd:1 +ttp: b561/782 bl:2.7104 bb:1.0629 rl:2.6798 rb:1.0515 dl:831-834 gd:1 +ttp: b554/782 bl:2.7337 bb:1.0290 rl:2.6805 rb:1.0512 dl:809-812 gd:1 +ttp: b551/782 bl:2.8238 bb:1.0644 rl:2.6825 rb:1.0514 dl:801-804 gd:1 +ttp: b543/782 bl:2.7786 bb:1.0431 rl:2.6837 rb:1.0512 dl:779-782 gd:1 +ttp: b533/782 bl:2.7622 bb:1.0317 rl:2.6847 rb:1.0510 dl:754-757 gd:1 +ttp: b525/782 bl:2.7768 bb:1.0683 rl:2.6858 rb:1.0512 dl:735-737 gd:1 +ttp: b517/782 bl:2.7706 bb:1.0487 rl:2.6868 rb:1.0512 dl:715-717 gd:1 +ttp: b513/782 bl:2.7326 bb:1.0114 rl:2.6873 rb:1.0507 dl:705-707 gd:1 +ttp: b505/782 bl:2.7678 bb:1.0575 rl:2.6882 rb:1.0508 dl:686-688 gd:1 +ttp: b497/782 bl:2.8306 bb:1.0793 rl:2.6897 rb:1.0511 dl:668-671 gd:1 +ttp: b490/782 bl:2.8494 bb:1.0888 rl:2.6913 rb:1.0515 dl:653-655 gd:1 +ttp: b482/782 bl:2.7501 bb:1.0793 rl:2.6918 rb:1.0517 dl:637-639 gd:1 +ttp: b475/782 bl:2.7208 bb:1.0201 rl:2.6921 rb:1.0514 dl:622-623 gd:1 +ttp: b471/782 bl:2.8371 bb:1.0690 rl:2.6934 rb:1.0516 dl:614-616 gd:1 +ttp: b462/782 bl:2.8592 bb:1.0652 rl:2.6949 rb:1.0517 dl:597-599 gd:1 +ttp: b453/782 bl:2.7533 bb:1.0566 rl:2.6954 rb:1.0518 dl:580-582 gd:1 +ttp: b445/782 bl:2.7672 bb:1.0640 rl:2.6960 rb:1.0519 dl:566-568 gd:1 +ttp: b437/782 bl:2.8753 bb:1.0610 rl:2.6974 rb:1.0519 dl:551-553 gd:1 +ttp: b429/782 bl:2.7534 bb:1.0805 rl:2.6979 rb:1.0522 dl:537-539 gd:1 +ttp: b421/782 bl:2.7856 bb:1.0508 rl:2.6985 rb:1.0522 dl:523-524 gd:1 +ttp: b413/782 bl:2.6462 bb:0.9977 rl:2.6982 rb:1.0517 dl:510-511 gd:1 +ttp: b404/782 bl:2.7863 bb:1.0692 rl:2.6988 rb:1.0519 dl:495-497 gd:1 +ttp: b396/782 bl:2.7555 bb:1.0544 rl:2.6991 rb:1.0519 dl:482-484 gd:1 +ttp: b388/782 bl:2.7746 bb:1.0646 rl:2.6996 rb:1.0520 dl:470-471 gd:1 +ttp: b380/782 bl:2.8399 bb:1.0753 rl:2.7005 rb:1.0521 dl:459-460 gd:1 +ttp: b372/782 bl:2.8289 bb:1.0665 rl:2.7013 rb:1.0522 dl:447-449 gd:1 +ttp: b364/782 bl:2.7348 bb:1.0667 rl:2.7015 rb:1.0523 dl:436-437 gd:1 +ttp: b356/782 bl:2.6874 bb:1.0442 rl:2.7014 rb:1.0523 dl:424-426 gd:1 +ttp: b348/782 bl:2.8021 bb:1.0648 rl:2.7020 rb:1.0523 dl:414-415 gd:1 +ttp: b340/782 bl:2.8216 bb:1.0914 rl:2.7026 rb:1.0525 dl:403-404 gd:1 +ttp: b332/782 bl:2.8212 bb:1.0954 rl:2.7033 rb:1.0528 dl:393-394 gd:1 +ttp: b324/782 bl:2.7675 bb:1.0556 rl:2.7036 rb:1.0528 dl:382-384 gd:1 +ttp: b316/782 bl:2.7664 bb:1.0879 rl:2.7039 rb:1.0529 dl:371-373 gd:1 +ttp: b308/782 bl:2.7894 bb:1.0837 rl:2.7043 rb:1.0531 dl:362-363 gd:1 +ttp: b300/782 bl:2.8528 bb:1.0874 rl:2.7050 rb:1.0533 dl:352-353 gd:1 +ttp: b292/782 bl:2.7805 bb:1.0774 rl:2.7053 rb:1.0534 dl:342-343 gd:1 +ttp: b284/782 bl:2.8823 bb:1.0866 rl:2.7061 rb:1.0535 dl:333-334 gd:1 +ttp: b276/782 bl:2.8378 bb:1.0999 rl:2.7066 rb:1.0537 dl:324-325 gd:1 +ttp: b267/782 bl:2.8575 bb:1.0957 rl:2.7072 rb:1.0539 dl:314-315 gd:1 +ttp: b256/782 bl:2.8880 bb:1.1321 rl:2.7079 rb:1.0542 dl:301-302 gd:1 +ttp: b247/782 bl:2.7870 bb:1.0768 rl:2.7082 rb:1.0543 dl:292-293 gd:1 +ttp: b238/782 bl:2.8877 bb:1.1455 rl:2.7089 rb:1.0546 dl:283-284 gd:1 +ttp: b229/782 bl:2.8859 bb:1.1352 rl:2.7095 rb:1.0549 dl:274-275 gd:1 +ttp: b221/782 bl:2.8326 bb:1.1368 rl:2.7099 rb:1.0551 dl:266-267 gd:1 +ttp: b213/782 bl:2.9992 bb:1.1702 rl:2.7109 rb:1.0555 dl:258-259 gd:1 +ttp: b205/782 bl:2.8411 bb:1.1085 rl:2.7113 rb:1.0557 dl:251-252 gd:1 +ttp: b197/782 bl:2.8495 bb:1.1237 rl:2.7117 rb:1.0559 dl:244-245 gd:1 +ttp: b188/782 bl:2.9099 bb:1.1527 rl:2.7123 rb:1.0562 dl:236-237 gd:1 +ttp: b180/782 bl:2.9095 bb:1.1346 rl:2.7128 rb:1.0564 dl:229-230 gd:1 +ttp: b172/782 bl:3.0020 bb:1.1806 rl:2.7136 rb:1.0567 dl:222-223 gd:1 +ttp: b162/782 bl:2.9530 bb:1.1458 rl:2.7143 rb:1.0569 dl:213-214 gd:1 +ttp: b154/782 bl:2.9834 bb:1.1548 rl:2.7149 rb:1.0572 dl:207-207 gd:1 +ttp: b142/782 bl:2.9561 bb:1.1589 rl:2.7155 rb:1.0574 dl:197-198 gd:1 +ttp: b134/782 bl:3.0305 bb:1.2120 rl:2.7163 rb:1.0578 dl:190-191 gd:1 +ttp: b125/782 bl:3.0006 bb:1.1891 rl:2.7169 rb:1.0581 dl:184-185 gd:1 +ttp: b118/782 bl:2.9368 bb:1.1474 rl:2.7174 rb:1.0583 dl:178-179 gd:1 +ttp: b109/782 bl:3.0737 bb:1.2113 rl:2.7181 rb:1.0586 dl:172-173 gd:1 +ttp: b97/782 bl:3.0084 bb:1.1753 rl:2.7187 rb:1.0588 dl:163-164 gd:1 +ttp: b88/782 bl:3.0943 bb:1.2048 rl:2.7194 rb:1.0591 dl:156-157 gd:1 +ttp: b81/782 bl:2.9257 bb:1.1636 rl:2.7198 rb:1.0593 dl:151-151 gd:1 +ttp: b71/782 bl:2.9532 bb:1.1522 rl:2.7202 rb:1.0594 dl:143-144 gd:1 +ttp: b63/782 bl:3.0100 bb:1.2140 rl:2.7207 rb:1.0597 dl:137-138 gd:1 +ttp: b55/782 bl:3.0794 bb:1.2367 rl:2.7212 rb:1.0600 dl:130-131 gd:1 +ttp: b46/782 bl:3.1249 bb:1.2219 rl:2.7218 rb:1.0602 dl:123-124 gd:1 +ttp: b34/782 bl:3.0882 bb:1.2502 rl:2.7223 rb:1.0605 dl:114-115 gd:1 +ttp: b27/782 bl:3.0954 bb:1.2357 rl:2.7228 rb:1.0607 dl:107-108 gd:1 +ttp: b17/782 bl:3.1262 bb:1.2392 rl:2.7233 rb:1.0609 dl:98-99 gd:1 +ttp: b6/782 bl:3.2750 bb:1.2778 rl:2.7238 rb:1.0611 dl:82-84 gd:1 +quantized_ttt_phased val_loss:2.77032226 val_bpb:1.07247808 eval_time:456748ms +total_eval_time:456.7s +[Wed Apr 22 00:08:47 UTC 2026] Experiment pure_ttt_sp8192_dxa144_seed42_8gpu complete +Final BPB: diagnostic pre-quantization post-ema val_loss:2.77158098 val_bpb:1.07293079 eval_time:6634ms +diagnostic quantized val_loss:2.80209025 val_bpb:1.08474150 eval_time:55029ms +quantized_ttt_phased val_loss:2.77032226 val_bpb:1.07247808 eval_time:456748ms