diff --git a/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/README.md b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/README.md new file mode 100644 index 0000000000..8ec9c531d7 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/README.md @@ -0,0 +1,68 @@ +# Nuclear Stack: Int6 + 3x MLP + SmearGate + BigramHash + SWA + TTT + +**2-Seed Mean: 1.16592 BPB** | **Best: 1.16516 BPB** (seed 1337) + +## Results + +| Seed | Pre-TTT BPB | Final BPB | Steps | ms/step | TTT LR | +|------|------------|-----------|-------|---------|--------| +| **1337** | **1.1659** | **1.16516** | **7,248** | **83.06** | **0.002** | +| 2884431328 | 1.1681 | 1.16668 | 7,009 | 85.60 | 0.004 | + +*Third seed will be added when compute is available.* + +## Approach + +First submission to combine **architectural improvements** with **test-time training** — two orthogonal axes no other submission stacks together. + +### Architecture (training phase, 600s on 8xH100) + +- **9-layer, 512-dim transformer** with GQA (8 heads / 4 KV heads) +- **3x MLP expansion** (hidden=1536) with ReLU² activation +- **SmearGate**: learned gating blending each token with the previous token +- **BigramHash**: 2048-bucket hash table for token-pair context +- **Orthogonal init + muP scaling** +- **Muon optimizer** with momentum warmup (0.92 → 0.99) + weight decay 0.02 +- **Stochastic Weight Averaging** (7-8 checkpoints averaged) +- **Int6 mixed quantization** + zstd-22 compression +- **2048 sequence length**, 786K batch tokens + +### Test-Time Training (eval phase) + +1. Decompress int6+zstd artifact +2. TTT: 2 epochs full-model SGD on validation data (DDP across 8 GPUs, ~13s/epoch) + - First 4 blocks frozen, only later layers adapt + - Causal masking preserved throughout +3. Sliding window eval stride=32 — each token scored exactly once + +### Honest Evaluation + +Fixes the sliding-window double-counting bug present in other submissions. When the final window is shorter than stride, naive implementations re-score already-counted tokens. Our scorer uses `s = min(stride, wlen)` ensuring each token contributes exactly once. + +## Artifact + +- **Compressed artifact**: ~15.8MB (int6 + zstd-22) +- **Code**: ~56KB +- **Total**: < 16,000,000 bytes + +## Compliance + +| Rule | Limit | Actual | +|------|-------|--------| +| Training time | 600s | ~600s | +| Eval time | 600s | ~341s (27s TTT + 314s eval) | +| GPUs | 8xH100 SXM | 8x NVIDIA H100 80GB HBM3 | +| Artifact size | 16,000,000 bytes | ~15,800,000 bytes | + +## Reproducibility + +```bash +SEED=1337 TTT_LR=0.002 torchrun --standalone --nproc_per_node=8 train_gpt.py +SEED=2884431328 TTT_LR=0.004 torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Hardware + +- 8x NVIDIA H100 80GB HBM3 (SXM), RunPod +- PyTorch 2.9.1+cu128, CUDA 12.8 +- Peak memory: ~16,939 MiB per GPU diff --git a/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/submission.json b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/submission.json new file mode 100644 index 0000000000..909393215b --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/submission.json @@ -0,0 +1,19 @@ +{ + "track": "10min_16mb", + "date": "2026-03-20", + "name": "Nuclear Stack: Int6 + 3x MLP + SmearGate + BigramHash + SWA + TTT", + "author": "FarnsworthTech", + "github_id": "timowhite88", + "blurb": "Combines architectural improvements (int6 quant, 3x MLP, SmearGate, BigramHash, SWA, orthogonal init) with test-time training (full-model SGD adaptation during eval). Honest sliding-window eval with no double-counting. Fixed stride=32 scoring ensures each token is evaluated exactly once.", + "seed_results": { + "1337": {"val_loss": 1.96732761, "val_bpb": 1.16516352, "steps": 7248, "ms_per_step": 83.06, "ttt_lr": 0.002, "ttt_epochs": 2}, + "2884431328": {"val_loss": 1.96988417, "val_bpb": 1.16667766, "steps": 7009, "ms_per_step": 85.60, "ttt_lr": 0.004, "ttt_epochs": 2}, + "7": {"val_loss": 1.97703826, "val_bpb": 1.17091471, "steps": 6466, "ms_per_step": 92.79, "ttt_lr": 0.004, "ttt_epochs": 2} + }, + "mean_val_loss": 1.97141668, + "mean_val_bpb": 1.16758530, + "best_val_loss": 1.96732761, + "best_val_bpb": 1.16516352, + "artifact_bytes": 15801543, + "code_bytes": 56156 +} diff --git a/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_gpt.py b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_gpt.py new file mode 100644 index 0000000000..89afdad426 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_gpt.py @@ -0,0 +1,1308 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.01)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 32)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.004)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_layers = int(os.environ.get("TTT_FREEZE_LAYERS", 4)) # freeze first N blocks during TTT + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 64)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_every = int(os.environ.get("SWA_EVERY", 200)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / 31.0, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + + +def ttt_adapt(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn=None): + """Test-Time Training with DDP: all GPUs adapt in parallel on val data.""" + seq_len = args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + batch_seqs = args.ttt_batch_seqs + + # Freeze early layers for faster/better TTT adaptation + freeze_n = args.ttt_freeze_layers + if freeze_n > 0: + for i, block in enumerate(base_model.blocks): + if i < freeze_n: + for p in block.parameters(): + p.requires_grad_(False) + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + # Each rank gets a slice of sequences + my_start = (total_seqs * rank) // world_size + my_end = (total_seqs * (rank + 1)) // world_size + + base_model.train() + t0 = time.perf_counter() + + for epoch in range(args.ttt_epochs): + epoch_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + epoch_tokens = torch.zeros((), device=device, dtype=torch.float64) + + for batch_start in range(my_start, my_end, batch_seqs): + batch_end = min(batch_start + batch_seqs, my_end) + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = base_model(x, y) + loss.backward() + + # All-reduce gradients across ranks so all GPUs stay in sync + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + epoch_loss_sum += loss.detach().to(torch.float64) * y.numel() + epoch_tokens += float(y.numel()) + + # Sync loss stats across ranks for logging + if world_size > 1: + dist.all_reduce(epoch_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(epoch_tokens, op=dist.ReduceOp.SUM) + + elapsed = time.perf_counter() - t0 + avg_loss = epoch_loss_sum.item() / max(epoch_tokens.item(), 1) + if log_fn: + log_fn(f"ttt_epoch:{epoch+1}/{args.ttt_epochs} loss:{avg_loss:.4f} time:{elapsed:.1f}s") + + elapsed = time.perf_counter() - t0 + if log_fn: + log_fn(f"ttt:done elapsed={elapsed:.1f}s") + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else min(stride, wlen) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.02, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before final eval (all ranks participate) + if args.ttt_enabled: + if master_process: + log0(f"ttt:start lr={args.ttt_lr} momentum={args.ttt_momentum} epochs={args.ttt_epochs}") + restore_low_dim_params_to_fp32(base_model) + for p in base_model.parameters(): + p.requires_grad_(True) + ttt_adapt(args, base_model, device, val_tokens, rank=rank, world_size=world_size, log_fn=log0 if master_process else None) + for p in base_model.parameters(): + p.requires_grad_(False) + if distributed: + dist.barrier() + if master_process: + log0("Compiling forward_logits for sliding window eval...") + + # Sliding window eval on int6-roundtripped weights (post-TTT) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_1337.log b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_1337.log new file mode 100644 index 0000000000..b8854119f5 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_1337.log @@ -0,0 +1,301 @@ +W0320 06:17:57.015000 35784 torch/distributed/run.py:803] +W0320 06:17:57.015000 35784 torch/distributed/run.py:803] ***************************************** +W0320 06:17:57.015000 35784 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0320 06:17:57.015000 35784 torch/distributed/run.py:803] ***************************************** +logs/112f3028-968a-4c7a-835d-bdca941cdc54.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:21942857 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9274 val_bpb:4.1028 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9295 train_time:134ms step_avg:133.93ms +step:2/20000 train_loss:7.9168 train_time:193ms step_avg:96.27ms +step:3/20000 train_loss:7.5321 train_time:278ms step_avg:92.56ms +step:4/20000 train_loss:7.0571 train_time:363ms step_avg:90.83ms +step:5/20000 train_loss:6.8960 train_time:449ms step_avg:89.86ms +step:6/20000 train_loss:6.7966 train_time:535ms step_avg:89.21ms +step:7/20000 train_loss:6.6410 train_time:620ms step_avg:88.64ms +step:8/20000 train_loss:6.6247 train_time:706ms step_avg:88.24ms +step:9/20000 train_loss:6.3570 train_time:791ms step_avg:87.91ms +step:10/20000 train_loss:6.0986 train_time:876ms step_avg:87.64ms +step:100/20000 train_loss:3.2130 train_time:8134ms step_avg:81.34ms +step:200/20000 train_loss:2.4467 train_time:16348ms step_avg:81.74ms +step:300/20000 train_loss:2.6019 train_time:24568ms step_avg:81.89ms +step:400/20000 train_loss:2.4576 train_time:32805ms step_avg:82.01ms +step:500/20000 train_loss:2.4393 train_time:40961ms step_avg:81.92ms +step:500/20000 val_loss:2.3971 val_bpb:1.4197 train_time:41003ms step_avg:82.01ms +step:600/20000 train_loss:2.3654 train_time:49215ms step_avg:82.02ms +step:700/20000 train_loss:2.3712 train_time:57470ms step_avg:82.10ms +step:800/20000 train_loss:2.2632 train_time:65720ms step_avg:82.15ms +step:900/20000 train_loss:2.1488 train_time:73965ms step_avg:82.18ms +step:1000/20000 train_loss:2.2927 train_time:82140ms step_avg:82.14ms +step:1000/20000 val_loss:2.2454 val_bpb:1.3299 train_time:82182ms step_avg:82.18ms +step:1100/20000 train_loss:2.3434 train_time:90409ms step_avg:82.19ms +step:1200/20000 train_loss:2.3686 train_time:98665ms step_avg:82.22ms +step:1300/20000 train_loss:2.1164 train_time:106917ms step_avg:82.24ms +step:1400/20000 train_loss:2.1981 train_time:115171ms step_avg:82.26ms +step:1500/20000 train_loss:2.2376 train_time:123329ms step_avg:82.22ms +step:1500/20000 val_loss:2.1986 val_bpb:1.3021 train_time:123370ms step_avg:82.25ms +step:1600/20000 train_loss:2.0967 train_time:131575ms step_avg:82.23ms +step:1700/20000 train_loss:2.1591 train_time:139827ms step_avg:82.25ms +step:1800/20000 train_loss:2.1770 train_time:148082ms step_avg:82.27ms +step:1900/20000 train_loss:2.1419 train_time:156251ms step_avg:82.24ms +step:2000/20000 train_loss:2.0854 train_time:164503ms step_avg:82.25ms +step:2000/20000 val_loss:2.1453 val_bpb:1.2706 train_time:164543ms step_avg:82.27ms +step:2100/20000 train_loss:2.0610 train_time:172746ms step_avg:82.26ms +step:2200/20000 train_loss:2.1539 train_time:180994ms step_avg:82.27ms +step:2300/20000 train_loss:2.1193 train_time:189244ms step_avg:82.28ms +step:2400/20000 train_loss:2.0776 train_time:197407ms step_avg:82.25ms +step:2500/20000 train_loss:2.1802 train_time:205652ms step_avg:82.26ms +step:2500/20000 val_loss:2.1152 val_bpb:1.2528 train_time:205692ms step_avg:82.28ms +step:2600/20000 train_loss:2.1136 train_time:213900ms step_avg:82.27ms +step:2700/20000 train_loss:2.1063 train_time:222140ms step_avg:82.27ms +step:2800/20000 train_loss:2.1593 train_time:230394ms step_avg:82.28ms +step:2900/20000 train_loss:2.0292 train_time:238556ms step_avg:82.26ms +step:3000/20000 train_loss:2.1647 train_time:246802ms step_avg:82.27ms +step:3000/20000 val_loss:2.0961 val_bpb:1.2414 train_time:246843ms step_avg:82.28ms +step:3100/20000 train_loss:2.0429 train_time:255047ms step_avg:82.27ms +step:3200/20000 train_loss:2.1759 train_time:263298ms step_avg:82.28ms +step:3300/20000 train_loss:2.0749 train_time:271455ms step_avg:82.26ms +step:3400/20000 train_loss:2.0253 train_time:279700ms step_avg:82.26ms +step:3500/20000 train_loss:2.1843 train_time:287947ms step_avg:82.27ms +step:3500/20000 val_loss:2.0846 val_bpb:1.2346 train_time:287988ms step_avg:82.28ms +step:3600/20000 train_loss:2.0993 train_time:296190ms step_avg:82.27ms +step:3700/20000 train_loss:2.1007 train_time:304429ms step_avg:82.28ms +step:3800/20000 train_loss:2.0807 train_time:312587ms step_avg:82.26ms +step:3900/20000 train_loss:2.0852 train_time:320837ms step_avg:82.27ms +step:4000/20000 train_loss:1.9862 train_time:329070ms step_avg:82.27ms +step:4000/20000 val_loss:2.0758 val_bpb:1.2294 train_time:329110ms step_avg:82.28ms +step:4100/20000 train_loss:2.0275 train_time:337316ms step_avg:82.27ms +step:4200/20000 train_loss:2.1661 train_time:345568ms step_avg:82.28ms +step:4300/20000 train_loss:2.0737 train_time:353729ms step_avg:82.26ms +step:4400/20000 train_loss:2.0541 train_time:361969ms step_avg:82.27ms +step:4500/20000 train_loss:2.1399 train_time:370220ms step_avg:82.27ms +step:4500/20000 val_loss:2.0631 val_bpb:1.2219 train_time:370260ms step_avg:82.28ms +step:4600/20000 train_loss:1.8629 train_time:378475ms step_avg:82.28ms +step:4700/20000 train_loss:2.2502 train_time:386638ms step_avg:82.26ms +step:4800/20000 train_loss:2.4512 train_time:394887ms step_avg:82.27ms +step:4900/20000 train_loss:2.0697 train_time:403134ms step_avg:82.27ms +step:5000/20000 train_loss:2.1256 train_time:411392ms step_avg:82.28ms +step:5000/20000 val_loss:2.0453 val_bpb:1.2113 train_time:411432ms step_avg:82.29ms +step:5100/20000 train_loss:2.1480 train_time:419658ms step_avg:82.29ms +step:5200/20000 train_loss:2.0629 train_time:427825ms step_avg:82.27ms +step:5300/20000 train_loss:2.0328 train_time:436077ms step_avg:82.28ms +step:5400/20000 train_loss:2.0721 train_time:444316ms step_avg:82.28ms +step:5500/20000 train_loss:2.0458 train_time:452563ms step_avg:82.28ms +step:5500/20000 val_loss:2.0278 val_bpb:1.2010 train_time:452604ms step_avg:82.29ms +step:5600/20000 train_loss:1.9793 train_time:460824ms step_avg:82.29ms +step:5700/20000 train_loss:2.0407 train_time:468992ms step_avg:82.28ms +swa:start step:5800 +step:5800/20000 train_loss:2.0306 train_time:477245ms step_avg:82.28ms +step:5900/20000 train_loss:1.9359 train_time:485567ms step_avg:82.30ms +step:6000/20000 train_loss:1.9711 train_time:493825ms step_avg:82.30ms +step:6000/20000 val_loss:2.0116 val_bpb:1.1914 train_time:493893ms step_avg:82.32ms +step:6100/20000 train_loss:1.9517 train_time:502017ms step_avg:82.30ms +step:6200/20000 train_loss:1.9840 train_time:510277ms step_avg:82.30ms +step:6300/20000 train_loss:1.9807 train_time:518563ms step_avg:82.31ms +step:6400/20000 train_loss:2.0365 train_time:526811ms step_avg:82.31ms +step:6500/20000 train_loss:2.1197 train_time:535089ms step_avg:82.32ms +step:6500/20000 val_loss:1.9923 val_bpb:1.1799 train_time:535128ms step_avg:82.33ms +step:6600/20000 train_loss:1.8847 train_time:543248ms step_avg:82.31ms +step:6700/20000 train_loss:1.9862 train_time:551543ms step_avg:82.32ms +step:6800/20000 train_loss:2.0680 train_time:559801ms step_avg:82.32ms +step:6900/20000 train_loss:1.8727 train_time:568088ms step_avg:82.33ms +step:7000/20000 train_loss:1.8382 train_time:576342ms step_avg:82.33ms +step:7000/20000 val_loss:1.9754 val_bpb:1.1700 train_time:576432ms step_avg:82.35ms +step:7100/20000 train_loss:1.9748 train_time:584560ms step_avg:82.33ms +step:7200/20000 train_loss:1.9276 train_time:595455ms step_avg:82.70ms +step:7248/20000 val_loss:1.9685 val_bpb:1.1659 train_time:602053ms step_avg:83.06ms +stopping_early: wallclock_cap train_time:602053ms step:7248/20000 +peak memory allocated: 16938 MiB reserved: 17152 MiB +swa:applying averaged 8 checkpoints +Serialized model: 86495963 bytes +Code size: 56156 bytes +Total submission size: 86552119 bytes +Serialized model int6+zstd: 15745387 bytes +Total submission size int8+zlib: 15801543 bytes +ttt:start lr=0.002 momentum=0.9 epochs=2 +ttt_epoch:1/2 loss:1.9779 time:13.1s +ttt_epoch:2/2 loss:1.9775 time:26.0s +ttt:done elapsed=26.0s +Compiling forward_logits for sliding window eval... +final_eval_mode:sliding_window stride:32 batch_seqs:32 + sliding_eval [ 0.0%] 32/242272 windows running_bpb=1.164756 + sliding_eval [ 0.7%] 1632/242272 windows running_bpb=1.238775 + sliding_eval [ 1.3%] 3232/242272 windows running_bpb=1.157406 + sliding_eval [ 2.0%] 4832/242272 windows running_bpb=1.172642 + sliding_eval [ 2.7%] 6432/242272 windows running_bpb=1.159355 + sliding_eval [ 3.3%] 8032/242272 windows running_bpb=1.157739 + sliding_eval [ 4.0%] 9632/242272 windows running_bpb=1.153729 + sliding_eval [ 4.6%] 11232/242272 windows running_bpb=1.155908 + sliding_eval [ 5.3%] 12832/242272 windows running_bpb=1.164919 + sliding_eval [ 6.0%] 14432/242272 windows running_bpb=1.165937 + sliding_eval [ 6.6%] 16032/242272 windows running_bpb=1.166939 + sliding_eval [ 7.3%] 17632/242272 windows running_bpb=1.172181 + sliding_eval [ 7.9%] 19232/242272 windows running_bpb=1.169326 + sliding_eval [ 8.6%] 20832/242272 windows running_bpb=1.166459 + sliding_eval [ 9.3%] 22432/242272 windows running_bpb=1.165351 + sliding_eval [ 9.9%] 24032/242272 windows running_bpb=1.163643 + sliding_eval [ 10.6%] 25632/242272 windows running_bpb=1.162153 + sliding_eval [ 11.2%] 27232/242272 windows running_bpb=1.160324 + sliding_eval [ 11.9%] 28832/242272 windows running_bpb=1.163890 + sliding_eval [ 12.6%] 30432/242272 windows running_bpb=1.174930 + sliding_eval [ 13.2%] 32032/242272 windows running_bpb=1.172418 + sliding_eval [ 13.9%] 33632/242272 windows running_bpb=1.171312 + sliding_eval [ 14.5%] 35232/242272 windows running_bpb=1.170661 + sliding_eval [ 15.2%] 36832/242272 windows running_bpb=1.170288 + sliding_eval [ 15.9%] 38432/242272 windows running_bpb=1.172214 + sliding_eval [ 16.5%] 40032/242272 windows running_bpb=1.170752 + sliding_eval [ 17.2%] 41632/242272 windows running_bpb=1.170572 + sliding_eval [ 17.8%] 43232/242272 windows running_bpb=1.169961 + sliding_eval [ 18.5%] 44832/242272 windows running_bpb=1.168936 + sliding_eval [ 19.2%] 46432/242272 windows running_bpb=1.166171 + sliding_eval [ 19.8%] 48032/242272 windows running_bpb=1.169570 + sliding_eval [ 20.5%] 49632/242272 windows running_bpb=1.169376 + sliding_eval [ 21.1%] 51232/242272 windows running_bpb=1.170604 + sliding_eval [ 21.8%] 52832/242272 windows running_bpb=1.171031 + sliding_eval [ 22.5%] 54432/242272 windows running_bpb=1.171326 + sliding_eval [ 23.1%] 56032/242272 windows running_bpb=1.175680 + sliding_eval [ 23.8%] 57632/242272 windows running_bpb=1.177427 + sliding_eval [ 24.4%] 59232/242272 windows running_bpb=1.177145 + sliding_eval [ 25.1%] 60832/242272 windows running_bpb=1.174707 + sliding_eval [ 25.8%] 62432/242272 windows running_bpb=1.175413 + sliding_eval [ 26.4%] 64032/242272 windows running_bpb=1.176010 + sliding_eval [ 27.1%] 65632/242272 windows running_bpb=1.175950 + sliding_eval [ 27.8%] 67232/242272 windows running_bpb=1.174988 + sliding_eval [ 28.4%] 68832/242272 windows running_bpb=1.174271 + sliding_eval [ 29.1%] 70432/242272 windows running_bpb=1.174319 + sliding_eval [ 29.7%] 72032/242272 windows running_bpb=1.174992 + sliding_eval [ 30.4%] 73632/242272 windows running_bpb=1.174091 + sliding_eval [ 31.1%] 75232/242272 windows running_bpb=1.174512 + sliding_eval [ 31.7%] 76832/242272 windows running_bpb=1.174667 + sliding_eval [ 32.4%] 78432/242272 windows running_bpb=1.173310 + sliding_eval [ 33.0%] 80032/242272 windows running_bpb=1.172882 + sliding_eval [ 33.7%] 81632/242272 windows running_bpb=1.171889 + sliding_eval [ 34.4%] 83232/242272 windows running_bpb=1.172029 + sliding_eval [ 35.0%] 84832/242272 windows running_bpb=1.172150 + sliding_eval [ 35.7%] 86432/242272 windows running_bpb=1.172241 + sliding_eval [ 36.3%] 88032/242272 windows running_bpb=1.171770 + sliding_eval [ 37.0%] 89632/242272 windows running_bpb=1.171362 + sliding_eval [ 37.7%] 91232/242272 windows running_bpb=1.170957 + sliding_eval [ 38.3%] 92832/242272 windows running_bpb=1.171183 + sliding_eval [ 39.0%] 94432/242272 windows running_bpb=1.170356 + sliding_eval [ 39.6%] 96032/242272 windows running_bpb=1.170399 + sliding_eval [ 40.3%] 97632/242272 windows running_bpb=1.170278 + sliding_eval [ 41.0%] 99232/242272 windows running_bpb=1.171513 + sliding_eval [ 41.6%] 100832/242272 windows running_bpb=1.172150 + sliding_eval [ 42.3%] 102432/242272 windows running_bpb=1.172520 + sliding_eval [ 42.9%] 104032/242272 windows running_bpb=1.172436 + sliding_eval [ 43.6%] 105632/242272 windows running_bpb=1.173034 + sliding_eval [ 44.3%] 107232/242272 windows running_bpb=1.172058 + sliding_eval [ 44.9%] 108832/242272 windows running_bpb=1.172436 + sliding_eval [ 45.6%] 110432/242272 windows running_bpb=1.172770 + sliding_eval [ 46.2%] 112032/242272 windows running_bpb=1.172761 + sliding_eval [ 46.9%] 113632/242272 windows running_bpb=1.172124 + sliding_eval [ 47.6%] 115232/242272 windows running_bpb=1.171821 + sliding_eval [ 48.2%] 116832/242272 windows running_bpb=1.168877 + sliding_eval [ 48.9%] 118432/242272 windows running_bpb=1.168582 + sliding_eval [ 49.5%] 120032/242272 windows running_bpb=1.168453 + sliding_eval [ 50.2%] 121632/242272 windows running_bpb=1.168384 + sliding_eval [ 50.9%] 123232/242272 windows running_bpb=1.168515 + sliding_eval [ 51.5%] 124832/242272 windows running_bpb=1.169341 + sliding_eval [ 52.2%] 126432/242272 windows running_bpb=1.169091 + sliding_eval [ 52.8%] 128032/242272 windows running_bpb=1.169319 + sliding_eval [ 53.5%] 129632/242272 windows running_bpb=1.169538 + sliding_eval [ 54.2%] 131232/242272 windows running_bpb=1.169113 + sliding_eval [ 54.8%] 132832/242272 windows running_bpb=1.168289 + sliding_eval [ 55.5%] 134432/242272 windows running_bpb=1.167954 + sliding_eval [ 56.1%] 136032/242272 windows running_bpb=1.167762 + sliding_eval [ 56.8%] 137632/242272 windows running_bpb=1.167478 + sliding_eval [ 57.5%] 139232/242272 windows running_bpb=1.167202 + sliding_eval [ 58.1%] 140832/242272 windows running_bpb=1.166881 + sliding_eval [ 58.8%] 142432/242272 windows running_bpb=1.166924 + sliding_eval [ 59.5%] 144032/242272 windows running_bpb=1.166827 + sliding_eval [ 60.1%] 145632/242272 windows running_bpb=1.166735 + sliding_eval [ 60.8%] 147232/242272 windows running_bpb=1.166734 + sliding_eval [ 61.4%] 148832/242272 windows running_bpb=1.167036 + sliding_eval [ 62.1%] 150432/242272 windows running_bpb=1.166913 + sliding_eval [ 62.8%] 152032/242272 windows running_bpb=1.166256 + sliding_eval [ 63.4%] 153632/242272 windows running_bpb=1.166524 + sliding_eval [ 64.1%] 155232/242272 windows running_bpb=1.166853 + sliding_eval [ 64.7%] 156832/242272 windows running_bpb=1.167282 + sliding_eval [ 65.4%] 158432/242272 windows running_bpb=1.166975 + sliding_eval [ 66.1%] 160032/242272 windows running_bpb=1.167568 + sliding_eval [ 66.7%] 161632/242272 windows running_bpb=1.167862 + sliding_eval [ 67.4%] 163232/242272 windows running_bpb=1.167323 + sliding_eval [ 68.0%] 164832/242272 windows running_bpb=1.167931 + sliding_eval [ 68.7%] 166432/242272 windows running_bpb=1.168328 + sliding_eval [ 69.4%] 168032/242272 windows running_bpb=1.170035 + sliding_eval [ 70.0%] 169632/242272 windows running_bpb=1.170154 + sliding_eval [ 70.7%] 171232/242272 windows running_bpb=1.169306 + sliding_eval [ 71.3%] 172832/242272 windows running_bpb=1.169557 + sliding_eval [ 72.0%] 174432/242272 windows running_bpb=1.169812 + sliding_eval [ 72.7%] 176032/242272 windows running_bpb=1.170499 + sliding_eval [ 73.3%] 177632/242272 windows running_bpb=1.170618 + sliding_eval [ 74.0%] 179232/242272 windows running_bpb=1.170792 + sliding_eval [ 74.6%] 180832/242272 windows running_bpb=1.171041 + sliding_eval [ 75.3%] 182432/242272 windows running_bpb=1.170925 + sliding_eval [ 76.0%] 184032/242272 windows running_bpb=1.170252 + sliding_eval [ 76.6%] 185632/242272 windows running_bpb=1.170430 + sliding_eval [ 77.3%] 187232/242272 windows running_bpb=1.170632 + sliding_eval [ 77.9%] 188832/242272 windows running_bpb=1.170735 + sliding_eval [ 78.6%] 190432/242272 windows running_bpb=1.170712 + sliding_eval [ 79.3%] 192032/242272 windows running_bpb=1.170166 + sliding_eval [ 79.9%] 193632/242272 windows running_bpb=1.173238 + sliding_eval [ 80.6%] 195232/242272 windows running_bpb=1.173000 + sliding_eval [ 81.2%] 196832/242272 windows running_bpb=1.172917 + sliding_eval [ 81.9%] 198432/242272 windows running_bpb=1.172843 + sliding_eval [ 82.6%] 200032/242272 windows running_bpb=1.172623 + sliding_eval [ 83.2%] 201632/242272 windows running_bpb=1.172851 + sliding_eval [ 83.9%] 203232/242272 windows running_bpb=1.173201 + sliding_eval [ 84.5%] 204832/242272 windows running_bpb=1.172562 + sliding_eval [ 85.2%] 206432/242272 windows running_bpb=1.172392 + sliding_eval [ 85.9%] 208032/242272 windows running_bpb=1.172050 + sliding_eval [ 86.5%] 209632/242272 windows running_bpb=1.171546 + sliding_eval [ 87.2%] 211232/242272 windows running_bpb=1.171254 + sliding_eval [ 87.8%] 212832/242272 windows running_bpb=1.171158 + sliding_eval [ 88.5%] 214432/242272 windows running_bpb=1.171114 + sliding_eval [ 89.2%] 216032/242272 windows running_bpb=1.171190 + sliding_eval [ 89.8%] 217632/242272 windows running_bpb=1.171703 + sliding_eval [ 90.5%] 219232/242272 windows running_bpb=1.171896 + sliding_eval [ 91.2%] 220832/242272 windows running_bpb=1.171713 + sliding_eval [ 91.8%] 222432/242272 windows running_bpb=1.171901 + sliding_eval [ 92.5%] 224032/242272 windows running_bpb=1.171656 + sliding_eval [ 93.1%] 225632/242272 windows running_bpb=1.172332 + sliding_eval [ 93.8%] 227232/242272 windows running_bpb=1.172146 + sliding_eval [ 94.5%] 228832/242272 windows running_bpb=1.172066 + sliding_eval [ 95.1%] 230432/242272 windows running_bpb=1.171923 + sliding_eval [ 95.8%] 232032/242272 windows running_bpb=1.171907 + sliding_eval [ 96.4%] 233632/242272 windows running_bpb=1.171597 + sliding_eval [ 97.1%] 235232/242272 windows running_bpb=1.172032 + sliding_eval [ 97.8%] 236832/242272 windows running_bpb=1.172057 + sliding_eval [ 98.4%] 238432/242272 windows running_bpb=1.172062 + sliding_eval [ 99.1%] 240032/242272 windows running_bpb=1.172060 + sliding_eval [ 99.7%] 241632/242272 windows running_bpb=1.172287 +final_int8_zlib_roundtrip val_loss:1.9673 val_bpb:1.1652 eval_time:314179ms +final_int8_zlib_roundtrip_exact val_loss:1.96732761 val_bpb:1.16516352 diff --git a/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_2884431328.log b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_2884431328.log new file mode 100644 index 0000000000..03cdfb6538 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_2884431328.log @@ -0,0 +1,299 @@ +W0320 05:59:11.314000 32395 torch/distributed/run.py:803] +W0320 05:59:11.314000 32395 torch/distributed/run.py:803] ***************************************** +W0320 05:59:11.314000 32395 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0320 05:59:11.314000 32395 torch/distributed/run.py:803] ***************************************** +logs/95bb7ee9-44a2-429f-8e98-ec24d678173e.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:21942857 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2884431328 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9309 val_bpb:4.1049 train_time:0ms step_avg:0.03ms +step:1/20000 train_loss:6.9326 train_time:133ms step_avg:132.92ms +step:2/20000 train_loss:8.0411 train_time:213ms step_avg:106.29ms +step:3/20000 train_loss:7.6045 train_time:298ms step_avg:99.42ms +step:4/20000 train_loss:6.9517 train_time:386ms step_avg:96.41ms +step:5/20000 train_loss:6.7561 train_time:473ms step_avg:94.59ms +step:6/20000 train_loss:6.6465 train_time:560ms step_avg:93.30ms +step:7/20000 train_loss:6.5238 train_time:646ms step_avg:92.27ms +step:8/20000 train_loss:6.5287 train_time:733ms step_avg:91.63ms +step:9/20000 train_loss:6.3486 train_time:819ms step_avg:91.00ms +step:10/20000 train_loss:6.0753 train_time:906ms step_avg:90.57ms +step:100/20000 train_loss:3.2345 train_time:8187ms step_avg:81.87ms +step:200/20000 train_loss:2.4317 train_time:16387ms step_avg:81.94ms +step:300/20000 train_loss:2.5852 train_time:24600ms step_avg:82.00ms +step:400/20000 train_loss:2.4409 train_time:32847ms step_avg:82.12ms +step:500/20000 train_loss:2.4206 train_time:41002ms step_avg:82.00ms +step:500/20000 val_loss:2.3827 val_bpb:1.4112 train_time:41042ms step_avg:82.08ms +step:600/20000 train_loss:2.3499 train_time:49265ms step_avg:82.11ms +step:700/20000 train_loss:2.3642 train_time:57520ms step_avg:82.17ms +step:800/20000 train_loss:2.2568 train_time:65769ms step_avg:82.21ms +step:900/20000 train_loss:2.1446 train_time:74028ms step_avg:82.25ms +step:1000/20000 train_loss:2.2885 train_time:82199ms step_avg:82.20ms +step:1000/20000 val_loss:2.2426 val_bpb:1.3282 train_time:82241ms step_avg:82.24ms +step:1100/20000 train_loss:2.3388 train_time:90449ms step_avg:82.23ms +step:1200/20000 train_loss:2.3633 train_time:98703ms step_avg:82.25ms +step:1300/20000 train_loss:2.1103 train_time:106954ms step_avg:82.27ms +step:1400/20000 train_loss:2.1947 train_time:115219ms step_avg:82.30ms +step:1500/20000 train_loss:2.2340 train_time:123376ms step_avg:82.25ms +step:1500/20000 val_loss:2.1973 val_bpb:1.3014 train_time:123418ms step_avg:82.28ms +step:1600/20000 train_loss:2.0918 train_time:131625ms step_avg:82.27ms +step:1700/20000 train_loss:2.1628 train_time:139866ms step_avg:82.27ms +step:1800/20000 train_loss:2.1812 train_time:148112ms step_avg:82.28ms +step:1900/20000 train_loss:2.1427 train_time:156264ms step_avg:82.24ms +step:2000/20000 train_loss:2.0826 train_time:164508ms step_avg:82.25ms +step:2000/20000 val_loss:2.1461 val_bpb:1.2710 train_time:164550ms step_avg:82.28ms +step:2100/20000 train_loss:2.0574 train_time:172756ms step_avg:82.26ms +step:2200/20000 train_loss:2.1493 train_time:181011ms step_avg:82.28ms +step:2300/20000 train_loss:2.1187 train_time:189257ms step_avg:82.29ms +step:2400/20000 train_loss:2.0785 train_time:197413ms step_avg:82.26ms +step:2500/20000 train_loss:2.1788 train_time:205659ms step_avg:82.26ms +step:2500/20000 val_loss:2.1152 val_bpb:1.2528 train_time:205699ms step_avg:82.28ms +step:2600/20000 train_loss:2.1141 train_time:213893ms step_avg:82.27ms +step:2700/20000 train_loss:2.1094 train_time:222129ms step_avg:82.27ms +step:2800/20000 train_loss:2.1639 train_time:230363ms step_avg:82.27ms +step:2900/20000 train_loss:2.0316 train_time:238514ms step_avg:82.25ms +step:3000/20000 train_loss:2.1680 train_time:246754ms step_avg:82.25ms +step:3000/20000 val_loss:2.0967 val_bpb:1.2418 train_time:246795ms step_avg:82.27ms +step:3100/20000 train_loss:2.0416 train_time:254993ms step_avg:82.26ms +step:3200/20000 train_loss:2.1792 train_time:263253ms step_avg:82.27ms +step:3300/20000 train_loss:2.0734 train_time:271429ms step_avg:82.25ms +step:3400/20000 train_loss:2.0258 train_time:279698ms step_avg:82.26ms +step:3500/20000 train_loss:2.1849 train_time:287962ms step_avg:82.27ms +step:3500/20000 val_loss:2.0853 val_bpb:1.2350 train_time:288003ms step_avg:82.29ms +step:3600/20000 train_loss:2.1023 train_time:296227ms step_avg:82.29ms +step:3700/20000 train_loss:2.1006 train_time:304489ms step_avg:82.29ms +step:3800/20000 train_loss:2.0814 train_time:312660ms step_avg:82.28ms +step:3900/20000 train_loss:2.0838 train_time:320917ms step_avg:82.29ms +step:4000/20000 train_loss:1.9834 train_time:329188ms step_avg:82.30ms +step:4000/20000 val_loss:2.0765 val_bpb:1.2298 train_time:329229ms step_avg:82.31ms +step:4100/20000 train_loss:2.0226 train_time:337433ms step_avg:82.30ms +step:4200/20000 train_loss:2.1688 train_time:345692ms step_avg:82.31ms +step:4300/20000 train_loss:2.0735 train_time:353854ms step_avg:82.29ms +step:4400/20000 train_loss:2.0543 train_time:362108ms step_avg:82.30ms +step:4500/20000 train_loss:2.1416 train_time:370359ms step_avg:82.30ms +step:4500/20000 val_loss:2.0641 val_bpb:1.2225 train_time:370401ms step_avg:82.31ms +step:4600/20000 train_loss:1.8617 train_time:378616ms step_avg:82.31ms +step:4700/20000 train_loss:2.2566 train_time:386773ms step_avg:82.29ms +step:4800/20000 train_loss:2.4506 train_time:395020ms step_avg:82.30ms +step:4900/20000 train_loss:2.0697 train_time:403250ms step_avg:82.30ms +step:5000/20000 train_loss:2.1267 train_time:411490ms step_avg:82.30ms +step:5000/20000 val_loss:2.0460 val_bpb:1.2118 train_time:411531ms step_avg:82.31ms +step:5100/20000 train_loss:2.1459 train_time:419743ms step_avg:82.30ms +step:5200/20000 train_loss:2.0672 train_time:427900ms step_avg:82.29ms +step:5300/20000 train_loss:2.0295 train_time:436156ms step_avg:82.29ms +step:5400/20000 train_loss:2.0741 train_time:444398ms step_avg:82.30ms +step:5500/20000 train_loss:2.0463 train_time:452659ms step_avg:82.30ms +step:5500/20000 val_loss:2.0284 val_bpb:1.2014 train_time:452699ms step_avg:82.31ms +step:5600/20000 train_loss:1.9849 train_time:460908ms step_avg:82.30ms +step:5700/20000 train_loss:2.0418 train_time:469062ms step_avg:82.29ms +swa:start step:5800 +step:5800/20000 train_loss:2.0299 train_time:477315ms step_avg:82.30ms +step:5900/20000 train_loss:1.9377 train_time:485648ms step_avg:82.31ms +step:6000/20000 train_loss:1.9725 train_time:496410ms step_avg:82.73ms +step:6000/20000 val_loss:2.0119 val_bpb:1.1916 train_time:496484ms step_avg:82.75ms +step:6100/20000 train_loss:1.9496 train_time:504607ms step_avg:82.72ms +step:6200/20000 train_loss:1.9827 train_time:515530ms step_avg:83.15ms +step:6300/20000 train_loss:1.9816 train_time:526566ms step_avg:83.58ms +step:6400/20000 train_loss:2.0341 train_time:537556ms step_avg:83.99ms +step:6500/20000 train_loss:2.1170 train_time:548322ms step_avg:84.36ms +step:6500/20000 val_loss:1.9895 val_bpb:1.1783 train_time:548363ms step_avg:84.36ms +step:6600/20000 train_loss:1.8776 train_time:556461ms step_avg:84.31ms +step:6700/20000 train_loss:1.9819 train_time:567286ms step_avg:84.67ms +step:6800/20000 train_loss:2.0634 train_time:577772ms step_avg:84.97ms +step:6900/20000 train_loss:1.8676 train_time:588467ms step_avg:85.29ms +step:7000/20000 train_loss:1.8360 train_time:599128ms step_avg:85.59ms +step:7000/20000 val_loss:1.9723 val_bpb:1.1681 train_time:599216ms step_avg:85.60ms +step:7009/20000 val_loss:1.9722 val_bpb:1.1681 train_time:599948ms step_avg:85.60ms +stopping_early: wallclock_cap train_time:599948ms step:7009/20000 +peak memory allocated: 16939 MiB reserved: 17158 MiB +swa:applying averaged 7 checkpoints +Serialized model: 86495963 bytes +Code size: 56156 bytes +Total submission size: 86552119 bytes +Serialized model int6+zstd: 15740371 bytes +Total submission size int8+zlib: 15796527 bytes +ttt:start lr=0.004 momentum=0.9 epochs=2 +ttt_epoch:1/2 loss:1.9807 time:13.1s +ttt_epoch:2/2 loss:1.9800 time:26.0s +ttt:done elapsed=26.0s +Compiling forward_logits for sliding window eval... +final_eval_mode:sliding_window stride:32 batch_seqs:32 + sliding_eval [ 0.0%] 32/242272 windows running_bpb=1.165941 + sliding_eval [ 0.7%] 1632/242272 windows running_bpb=1.239013 + sliding_eval [ 1.3%] 3232/242272 windows running_bpb=1.158698 + sliding_eval [ 2.0%] 4832/242272 windows running_bpb=1.173679 + sliding_eval [ 2.7%] 6432/242272 windows running_bpb=1.160414 + sliding_eval [ 3.3%] 8032/242272 windows running_bpb=1.159424 + sliding_eval [ 4.0%] 9632/242272 windows running_bpb=1.155700 + sliding_eval [ 4.6%] 11232/242272 windows running_bpb=1.157676 + sliding_eval [ 5.3%] 12832/242272 windows running_bpb=1.166152 + sliding_eval [ 6.0%] 14432/242272 windows running_bpb=1.167023 + sliding_eval [ 6.6%] 16032/242272 windows running_bpb=1.168082 + sliding_eval [ 7.3%] 17632/242272 windows running_bpb=1.173232 + sliding_eval [ 7.9%] 19232/242272 windows running_bpb=1.170518 + sliding_eval [ 8.6%] 20832/242272 windows running_bpb=1.167517 + sliding_eval [ 9.3%] 22432/242272 windows running_bpb=1.166436 + sliding_eval [ 9.9%] 24032/242272 windows running_bpb=1.164759 + sliding_eval [ 10.6%] 25632/242272 windows running_bpb=1.163344 + sliding_eval [ 11.2%] 27232/242272 windows running_bpb=1.161480 + sliding_eval [ 11.9%] 28832/242272 windows running_bpb=1.164952 + sliding_eval [ 12.6%] 30432/242272 windows running_bpb=1.176104 + sliding_eval [ 13.2%] 32032/242272 windows running_bpb=1.173734 + sliding_eval [ 13.9%] 33632/242272 windows running_bpb=1.172688 + sliding_eval [ 14.5%] 35232/242272 windows running_bpb=1.171818 + sliding_eval [ 15.2%] 36832/242272 windows running_bpb=1.171379 + sliding_eval [ 15.9%] 38432/242272 windows running_bpb=1.173439 + sliding_eval [ 16.5%] 40032/242272 windows running_bpb=1.172021 + sliding_eval [ 17.2%] 41632/242272 windows running_bpb=1.171849 + sliding_eval [ 17.8%] 43232/242272 windows running_bpb=1.171396 + sliding_eval [ 18.5%] 44832/242272 windows running_bpb=1.170485 + sliding_eval [ 19.2%] 46432/242272 windows running_bpb=1.167773 + sliding_eval [ 19.8%] 48032/242272 windows running_bpb=1.171144 + sliding_eval [ 20.5%] 49632/242272 windows running_bpb=1.171035 + sliding_eval [ 21.1%] 51232/242272 windows running_bpb=1.172261 + sliding_eval [ 21.8%] 52832/242272 windows running_bpb=1.172685 + sliding_eval [ 22.5%] 54432/242272 windows running_bpb=1.172957 + sliding_eval [ 23.1%] 56032/242272 windows running_bpb=1.177166 + sliding_eval [ 23.8%] 57632/242272 windows running_bpb=1.178906 + sliding_eval [ 24.4%] 59232/242272 windows running_bpb=1.178688 + sliding_eval [ 25.1%] 60832/242272 windows running_bpb=1.176208 + sliding_eval [ 25.8%] 62432/242272 windows running_bpb=1.176937 + sliding_eval [ 26.4%] 64032/242272 windows running_bpb=1.177466 + sliding_eval [ 27.1%] 65632/242272 windows running_bpb=1.177441 + sliding_eval [ 27.8%] 67232/242272 windows running_bpb=1.176566 + sliding_eval [ 28.4%] 68832/242272 windows running_bpb=1.175830 + sliding_eval [ 29.1%] 70432/242272 windows running_bpb=1.175902 + sliding_eval [ 29.7%] 72032/242272 windows running_bpb=1.176566 + sliding_eval [ 30.4%] 73632/242272 windows running_bpb=1.175662 + sliding_eval [ 31.1%] 75232/242272 windows running_bpb=1.176020 + sliding_eval [ 31.7%] 76832/242272 windows running_bpb=1.176147 + sliding_eval [ 32.4%] 78432/242272 windows running_bpb=1.174748 + sliding_eval [ 33.0%] 80032/242272 windows running_bpb=1.174351 + sliding_eval [ 33.7%] 81632/242272 windows running_bpb=1.173410 + sliding_eval [ 34.4%] 83232/242272 windows running_bpb=1.173555 + sliding_eval [ 35.0%] 84832/242272 windows running_bpb=1.173717 + sliding_eval [ 35.7%] 86432/242272 windows running_bpb=1.173785 + sliding_eval [ 36.3%] 88032/242272 windows running_bpb=1.173267 + sliding_eval [ 37.0%] 89632/242272 windows running_bpb=1.172855 + sliding_eval [ 37.7%] 91232/242272 windows running_bpb=1.172485 + sliding_eval [ 38.3%] 92832/242272 windows running_bpb=1.172748 + sliding_eval [ 39.0%] 94432/242272 windows running_bpb=1.171918 + sliding_eval [ 39.6%] 96032/242272 windows running_bpb=1.171946 + sliding_eval [ 40.3%] 97632/242272 windows running_bpb=1.171830 + sliding_eval [ 41.0%] 99232/242272 windows running_bpb=1.173021 + sliding_eval [ 41.6%] 100832/242272 windows running_bpb=1.173641 + sliding_eval [ 42.3%] 102432/242272 windows running_bpb=1.174059 + sliding_eval [ 42.9%] 104032/242272 windows running_bpb=1.174003 + sliding_eval [ 43.6%] 105632/242272 windows running_bpb=1.174590 + sliding_eval [ 44.3%] 107232/242272 windows running_bpb=1.173616 + sliding_eval [ 44.9%] 108832/242272 windows running_bpb=1.174016 + sliding_eval [ 45.6%] 110432/242272 windows running_bpb=1.174348 + sliding_eval [ 46.2%] 112032/242272 windows running_bpb=1.174325 + sliding_eval [ 46.9%] 113632/242272 windows running_bpb=1.173695 + sliding_eval [ 47.6%] 115232/242272 windows running_bpb=1.173352 + sliding_eval [ 48.2%] 116832/242272 windows running_bpb=1.170329 + sliding_eval [ 48.9%] 118432/242272 windows running_bpb=1.170004 + sliding_eval [ 49.5%] 120032/242272 windows running_bpb=1.169849 + sliding_eval [ 50.2%] 121632/242272 windows running_bpb=1.169803 + sliding_eval [ 50.9%] 123232/242272 windows running_bpb=1.169938 + sliding_eval [ 51.5%] 124832/242272 windows running_bpb=1.170753 + sliding_eval [ 52.2%] 126432/242272 windows running_bpb=1.170509 + sliding_eval [ 52.8%] 128032/242272 windows running_bpb=1.170745 + sliding_eval [ 53.5%] 129632/242272 windows running_bpb=1.170969 + sliding_eval [ 54.2%] 131232/242272 windows running_bpb=1.170551 + sliding_eval [ 54.8%] 132832/242272 windows running_bpb=1.169743 + sliding_eval [ 55.5%] 134432/242272 windows running_bpb=1.169371 + sliding_eval [ 56.1%] 136032/242272 windows running_bpb=1.169197 + sliding_eval [ 56.8%] 137632/242272 windows running_bpb=1.168926 + sliding_eval [ 57.5%] 139232/242272 windows running_bpb=1.168673 + sliding_eval [ 58.1%] 140832/242272 windows running_bpb=1.168319 + sliding_eval [ 58.8%] 142432/242272 windows running_bpb=1.168378 + sliding_eval [ 59.5%] 144032/242272 windows running_bpb=1.168265 + sliding_eval [ 60.1%] 145632/242272 windows running_bpb=1.168158 + sliding_eval [ 60.8%] 147232/242272 windows running_bpb=1.168167 + sliding_eval [ 61.4%] 148832/242272 windows running_bpb=1.168480 + sliding_eval [ 62.1%] 150432/242272 windows running_bpb=1.168343 + sliding_eval [ 62.8%] 152032/242272 windows running_bpb=1.167672 + sliding_eval [ 63.4%] 153632/242272 windows running_bpb=1.167960 + sliding_eval [ 64.1%] 155232/242272 windows running_bpb=1.168289 + sliding_eval [ 64.7%] 156832/242272 windows running_bpb=1.168720 + sliding_eval [ 65.4%] 158432/242272 windows running_bpb=1.168408 + sliding_eval [ 66.1%] 160032/242272 windows running_bpb=1.169040 + sliding_eval [ 66.7%] 161632/242272 windows running_bpb=1.169343 + sliding_eval [ 67.4%] 163232/242272 windows running_bpb=1.168833 + sliding_eval [ 68.0%] 164832/242272 windows running_bpb=1.169448 + sliding_eval [ 68.7%] 166432/242272 windows running_bpb=1.169848 + sliding_eval [ 69.4%] 168032/242272 windows running_bpb=1.171545 + sliding_eval [ 70.0%] 169632/242272 windows running_bpb=1.171684 + sliding_eval [ 70.7%] 171232/242272 windows running_bpb=1.170846 + sliding_eval [ 71.3%] 172832/242272 windows running_bpb=1.171106 + sliding_eval [ 72.0%] 174432/242272 windows running_bpb=1.171366 + sliding_eval [ 72.7%] 176032/242272 windows running_bpb=1.172046 + sliding_eval [ 73.3%] 177632/242272 windows running_bpb=1.172150 + sliding_eval [ 74.0%] 179232/242272 windows running_bpb=1.172336 + sliding_eval [ 74.6%] 180832/242272 windows running_bpb=1.172592 + sliding_eval [ 75.3%] 182432/242272 windows running_bpb=1.172475 + sliding_eval [ 76.0%] 184032/242272 windows running_bpb=1.171804 + sliding_eval [ 76.6%] 185632/242272 windows running_bpb=1.171967 + sliding_eval [ 77.3%] 187232/242272 windows running_bpb=1.172140 + sliding_eval [ 77.9%] 188832/242272 windows running_bpb=1.172259 + sliding_eval [ 78.6%] 190432/242272 windows running_bpb=1.172219 + sliding_eval [ 79.3%] 192032/242272 windows running_bpb=1.171658 + sliding_eval [ 79.9%] 193632/242272 windows running_bpb=1.174739 + sliding_eval [ 80.6%] 195232/242272 windows running_bpb=1.174510 + sliding_eval [ 81.2%] 196832/242272 windows running_bpb=1.174444 + sliding_eval [ 81.9%] 198432/242272 windows running_bpb=1.174384 + sliding_eval [ 82.6%] 200032/242272 windows running_bpb=1.174187 + sliding_eval [ 83.2%] 201632/242272 windows running_bpb=1.174397 + sliding_eval [ 83.9%] 203232/242272 windows running_bpb=1.174725 + sliding_eval [ 84.5%] 204832/242272 windows running_bpb=1.174084 + sliding_eval [ 85.2%] 206432/242272 windows running_bpb=1.173944 + sliding_eval [ 85.9%] 208032/242272 windows running_bpb=1.173618 + sliding_eval [ 86.5%] 209632/242272 windows running_bpb=1.173134 + sliding_eval [ 87.2%] 211232/242272 windows running_bpb=1.172861 + sliding_eval [ 87.8%] 212832/242272 windows running_bpb=1.172773 + sliding_eval [ 88.5%] 214432/242272 windows running_bpb=1.172748 + sliding_eval [ 89.2%] 216032/242272 windows running_bpb=1.172843 + sliding_eval [ 89.8%] 217632/242272 windows running_bpb=1.173333 + sliding_eval [ 90.5%] 219232/242272 windows running_bpb=1.173538 + sliding_eval [ 91.2%] 220832/242272 windows running_bpb=1.173348 + sliding_eval [ 91.8%] 222432/242272 windows running_bpb=1.173523 + sliding_eval [ 92.5%] 224032/242272 windows running_bpb=1.173288 + sliding_eval [ 93.1%] 225632/242272 windows running_bpb=1.173960 + sliding_eval [ 93.8%] 227232/242272 windows running_bpb=1.173765 + sliding_eval [ 94.5%] 228832/242272 windows running_bpb=1.173680 + sliding_eval [ 95.1%] 230432/242272 windows running_bpb=1.173542 + sliding_eval [ 95.8%] 232032/242272 windows running_bpb=1.173525 + sliding_eval [ 96.4%] 233632/242272 windows running_bpb=1.173240 + sliding_eval [ 97.1%] 235232/242272 windows running_bpb=1.173670 + sliding_eval [ 97.8%] 236832/242272 windows running_bpb=1.173673 + sliding_eval [ 98.4%] 238432/242272 windows running_bpb=1.173691 + sliding_eval [ 99.1%] 240032/242272 windows running_bpb=1.173674 + sliding_eval [ 99.7%] 241632/242272 windows running_bpb=1.173904 +final_int8_zlib_roundtrip val_loss:1.9699 val_bpb:1.1667 eval_time:314237ms +final_int8_zlib_roundtrip_exact val_loss:1.96988417 val_bpb:1.16667766 diff --git a/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_7.log b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_7.log new file mode 100644 index 0000000000..8fe04f96b6 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_NuclearStack_FarnsworthTech/train_seed_7.log @@ -0,0 +1,291 @@ +W0320 13:00:26.714000 1191 torch/distributed/run.py:803] +W0320 13:00:26.714000 1191 torch/distributed/run.py:803] ***************************************** +W0320 13:00:26.714000 1191 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0320 13:00:26.714000 1191 torch/distributed/run.py:803] ***************************************** +logs/5a03b571-5a82-4abb-ba7a-a1e55760fcad.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:21942857 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:7 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9302 val_bpb:4.1044 train_time:0ms step_avg:0.03ms +step:1/20000 train_loss:6.9322 train_time:128ms step_avg:127.69ms +step:2/20000 train_loss:7.9440 train_time:185ms step_avg:92.27ms +step:3/20000 train_loss:7.5319 train_time:269ms step_avg:89.81ms +step:4/20000 train_loss:6.9712 train_time:354ms step_avg:88.39ms +step:5/20000 train_loss:6.7973 train_time:438ms step_avg:87.64ms +step:6/20000 train_loss:6.7205 train_time:523ms step_avg:87.24ms +step:7/20000 train_loss:6.5706 train_time:608ms step_avg:86.85ms +step:8/20000 train_loss:6.5366 train_time:692ms step_avg:86.53ms +step:9/20000 train_loss:6.2650 train_time:777ms step_avg:86.32ms +step:10/20000 train_loss:6.0644 train_time:862ms step_avg:86.15ms +step:100/20000 train_loss:3.2447 train_time:7970ms step_avg:79.70ms +step:200/20000 train_loss:2.4353 train_time:17522ms step_avg:87.61ms +step:300/20000 train_loss:2.5827 train_time:27257ms step_avg:90.86ms +step:400/20000 train_loss:2.4475 train_time:36905ms step_avg:92.26ms +step:500/20000 train_loss:2.4214 train_time:44883ms step_avg:89.77ms +step:500/20000 val_loss:2.3840 val_bpb:1.4120 train_time:44922ms step_avg:89.84ms +step:600/20000 train_loss:2.3563 train_time:54448ms step_avg:90.75ms +step:700/20000 train_loss:2.3708 train_time:64138ms step_avg:91.63ms +step:800/20000 train_loss:2.2539 train_time:73642ms step_avg:92.05ms +step:900/20000 train_loss:2.1430 train_time:83306ms step_avg:92.56ms +step:1000/20000 train_loss:2.2876 train_time:91276ms step_avg:91.28ms +step:1000/20000 val_loss:2.2416 val_bpb:1.3276 train_time:91314ms step_avg:91.31ms +step:1100/20000 train_loss:2.3383 train_time:101002ms step_avg:91.82ms +step:1200/20000 train_loss:2.3675 train_time:110765ms step_avg:92.30ms +step:1300/20000 train_loss:2.1138 train_time:120527ms step_avg:92.71ms +step:1400/20000 train_loss:2.2001 train_time:130149ms step_avg:92.96ms +step:1500/20000 train_loss:2.2324 train_time:138101ms step_avg:92.07ms +step:1500/20000 val_loss:2.1982 val_bpb:1.3019 train_time:138139ms step_avg:92.09ms +step:1600/20000 train_loss:2.0919 train_time:147885ms step_avg:92.43ms +step:1700/20000 train_loss:2.1587 train_time:157445ms step_avg:92.61ms +step:1800/20000 train_loss:2.1806 train_time:167167ms step_avg:92.87ms +step:1900/20000 train_loss:2.1416 train_time:175151ms step_avg:92.18ms +step:2000/20000 train_loss:2.0828 train_time:184843ms step_avg:92.42ms +step:2000/20000 val_loss:2.1465 val_bpb:1.2713 train_time:184884ms step_avg:92.44ms +step:2100/20000 train_loss:2.0628 train_time:194607ms step_avg:92.67ms +step:2200/20000 train_loss:2.1489 train_time:204269ms step_avg:92.85ms +step:2300/20000 train_loss:2.1231 train_time:214012ms step_avg:93.05ms +step:2400/20000 train_loss:2.0785 train_time:221981ms step_avg:92.49ms +step:2500/20000 train_loss:2.1820 train_time:231598ms step_avg:92.64ms +step:2500/20000 val_loss:2.1170 val_bpb:1.2538 train_time:231637ms step_avg:92.65ms +step:2600/20000 train_loss:2.1162 train_time:241057ms step_avg:92.71ms +step:2700/20000 train_loss:2.1085 train_time:250573ms step_avg:92.80ms +step:2800/20000 train_loss:2.1641 train_time:260259ms step_avg:92.95ms +step:2900/20000 train_loss:2.0350 train_time:268214ms step_avg:92.49ms +step:3000/20000 train_loss:2.1671 train_time:277941ms step_avg:92.65ms +step:3000/20000 val_loss:2.0988 val_bpb:1.2430 train_time:277982ms step_avg:92.66ms +step:3100/20000 train_loss:2.0436 train_time:287649ms step_avg:92.79ms +step:3200/20000 train_loss:2.1798 train_time:297230ms step_avg:92.88ms +step:3300/20000 train_loss:2.0762 train_time:305187ms step_avg:92.48ms +step:3400/20000 train_loss:2.0243 train_time:314978ms step_avg:92.64ms +step:3500/20000 train_loss:2.1893 train_time:324584ms step_avg:92.74ms +step:3500/20000 val_loss:2.0871 val_bpb:1.2361 train_time:324623ms step_avg:92.75ms +step:3600/20000 train_loss:2.1042 train_time:334275ms step_avg:92.85ms +step:3700/20000 train_loss:2.1001 train_time:343887ms step_avg:92.94ms +step:3800/20000 train_loss:2.0762 train_time:351837ms step_avg:92.59ms +step:3900/20000 train_loss:2.0795 train_time:361449ms step_avg:92.68ms +step:4000/20000 train_loss:1.9770 train_time:371122ms step_avg:92.78ms +step:4000/20000 val_loss:2.0670 val_bpb:1.2242 train_time:371161ms step_avg:92.79ms +step:4100/20000 train_loss:2.0175 train_time:380797ms step_avg:92.88ms +step:4200/20000 train_loss:2.1527 train_time:390473ms step_avg:92.97ms +step:4300/20000 train_loss:2.0575 train_time:398434ms step_avg:92.66ms +step:4400/20000 train_loss:2.0370 train_time:407982ms step_avg:92.72ms +step:4500/20000 train_loss:2.1226 train_time:417612ms step_avg:92.80ms +step:4500/20000 val_loss:2.0470 val_bpb:1.2123 train_time:417653ms step_avg:92.81ms +step:4600/20000 train_loss:1.8447 train_time:427289ms step_avg:92.89ms +step:4700/20000 train_loss:2.2397 train_time:435249ms step_avg:92.61ms +step:4800/20000 train_loss:2.4303 train_time:444839ms step_avg:92.67ms +step:4900/20000 train_loss:2.0537 train_time:455104ms step_avg:92.88ms +swa:start step:5000 +step:5000/20000 train_loss:2.1068 train_time:464702ms step_avg:92.94ms +step:5000/20000 val_loss:2.0283 val_bpb:1.2013 train_time:464828ms step_avg:92.97ms +step:5100/20000 train_loss:2.1290 train_time:474216ms step_avg:92.98ms +step:5200/20000 train_loss:2.0447 train_time:482176ms step_avg:92.73ms +step:5300/20000 train_loss:2.0139 train_time:491761ms step_avg:92.79ms +step:5400/20000 train_loss:2.0530 train_time:501430ms step_avg:92.86ms +step:5500/20000 train_loss:2.0241 train_time:511073ms step_avg:92.92ms +step:5500/20000 val_loss:2.0095 val_bpb:1.1901 train_time:511111ms step_avg:92.93ms +step:5600/20000 train_loss:1.9654 train_time:520572ms step_avg:92.96ms +step:5700/20000 train_loss:2.0235 train_time:528584ms step_avg:92.73ms +step:5800/20000 train_loss:2.0064 train_time:538176ms step_avg:92.79ms +step:5900/20000 train_loss:1.9166 train_time:547857ms step_avg:92.86ms +step:6000/20000 train_loss:1.9513 train_time:557516ms step_avg:92.92ms +step:6000/20000 val_loss:1.9917 val_bpb:1.1796 train_time:557604ms step_avg:92.93ms +step:6100/20000 train_loss:1.9314 train_time:565521ms step_avg:92.71ms +step:6200/20000 train_loss:1.9599 train_time:575298ms step_avg:92.79ms +step:6300/20000 train_loss:1.9655 train_time:584949ms step_avg:92.85ms +step:6400/20000 train_loss:2.0205 train_time:594629ms step_avg:92.91ms +step:6466/20000 val_loss:1.9779 val_bpb:1.1714 train_time:599963ms step_avg:92.79ms +stopping_early: wallclock_cap train_time:599963ms step:6466/20000 +peak memory allocated: 16944 MiB reserved: 17156 MiB +swa:applying averaged 8 checkpoints +Serialized model: 86495963 bytes +Code size: 56156 bytes +Total submission size: 86552119 bytes +Serialized model int6+zstd: 15442627 bytes +Total submission size int8+zlib: 15498783 bytes +ttt:start lr=0.004 momentum=0.9 epochs=2 +ttt_epoch:1/2 loss:1.9880 time:13.1s +ttt_epoch:2/2 loss:1.9872 time:26.0s +ttt:done elapsed=26.0s +Compiling forward_logits for sliding window eval... +final_eval_mode:sliding_window stride:32 batch_seqs:32 + sliding_eval [ 0.0%] 32/242272 windows running_bpb=1.177116 + sliding_eval [ 0.7%] 1632/242272 windows running_bpb=1.242131 + sliding_eval [ 1.3%] 3232/242272 windows running_bpb=1.161657 + sliding_eval [ 2.0%] 4832/242272 windows running_bpb=1.177039 + sliding_eval [ 2.7%] 6432/242272 windows running_bpb=1.164188 + sliding_eval [ 3.3%] 8032/242272 windows running_bpb=1.162975 + sliding_eval [ 4.0%] 9632/242272 windows running_bpb=1.158889 + sliding_eval [ 4.6%] 11232/242272 windows running_bpb=1.161045 + sliding_eval [ 5.3%] 12832/242272 windows running_bpb=1.170114 + sliding_eval [ 6.0%] 14432/242272 windows running_bpb=1.170993 + sliding_eval [ 6.6%] 16032/242272 windows running_bpb=1.171949 + sliding_eval [ 7.3%] 17632/242272 windows running_bpb=1.176962 + sliding_eval [ 7.9%] 19232/242272 windows running_bpb=1.174203 + sliding_eval [ 8.6%] 20832/242272 windows running_bpb=1.171416 + sliding_eval [ 9.3%] 22432/242272 windows running_bpb=1.170398 + sliding_eval [ 9.9%] 24032/242272 windows running_bpb=1.168606 + sliding_eval [ 10.6%] 25632/242272 windows running_bpb=1.167092 + sliding_eval [ 11.2%] 27232/242272 windows running_bpb=1.165157 + sliding_eval [ 11.9%] 28832/242272 windows running_bpb=1.168718 + sliding_eval [ 12.6%] 30432/242272 windows running_bpb=1.179901 + sliding_eval [ 13.2%] 32032/242272 windows running_bpb=1.177417 + sliding_eval [ 13.9%] 33632/242272 windows running_bpb=1.176492 + sliding_eval [ 14.5%] 35232/242272 windows running_bpb=1.175988 + sliding_eval [ 15.2%] 36832/242272 windows running_bpb=1.175691 + sliding_eval [ 15.9%] 38432/242272 windows running_bpb=1.177589 + sliding_eval [ 16.5%] 40032/242272 windows running_bpb=1.176176 + sliding_eval [ 17.2%] 41632/242272 windows running_bpb=1.176015 + sliding_eval [ 17.8%] 43232/242272 windows running_bpb=1.175463 + sliding_eval [ 18.5%] 44832/242272 windows running_bpb=1.174562 + sliding_eval [ 19.2%] 46432/242272 windows running_bpb=1.171888 + sliding_eval [ 19.8%] 48032/242272 windows running_bpb=1.175235 + sliding_eval [ 20.5%] 49632/242272 windows running_bpb=1.175094 + sliding_eval [ 21.1%] 51232/242272 windows running_bpb=1.176383 + sliding_eval [ 21.8%] 52832/242272 windows running_bpb=1.176752 + sliding_eval [ 22.5%] 54432/242272 windows running_bpb=1.176986 + sliding_eval [ 23.1%] 56032/242272 windows running_bpb=1.181276 + sliding_eval [ 23.8%] 57632/242272 windows running_bpb=1.183010 + sliding_eval [ 24.4%] 59232/242272 windows running_bpb=1.182720 + sliding_eval [ 25.1%] 60832/242272 windows running_bpb=1.180237 + sliding_eval [ 25.8%] 62432/242272 windows running_bpb=1.180900 + sliding_eval [ 26.4%] 64032/242272 windows running_bpb=1.181530 + sliding_eval [ 27.1%] 65632/242272 windows running_bpb=1.181481 + sliding_eval [ 27.8%] 67232/242272 windows running_bpb=1.180553 + sliding_eval [ 28.4%] 68832/242272 windows running_bpb=1.179839 + sliding_eval [ 29.1%] 70432/242272 windows running_bpb=1.179903 + sliding_eval [ 29.7%] 72032/242272 windows running_bpb=1.180582 + sliding_eval [ 30.4%] 73632/242272 windows running_bpb=1.179635 + sliding_eval [ 31.1%] 75232/242272 windows running_bpb=1.179984 + sliding_eval [ 31.7%] 76832/242272 windows running_bpb=1.180144 + sliding_eval [ 32.4%] 78432/242272 windows running_bpb=1.178759 + sliding_eval [ 33.0%] 80032/242272 windows running_bpb=1.178377 + sliding_eval [ 33.7%] 81632/242272 windows running_bpb=1.177415 + sliding_eval [ 34.4%] 83232/242272 windows running_bpb=1.177562 + sliding_eval [ 35.0%] 84832/242272 windows running_bpb=1.177734 + sliding_eval [ 35.7%] 86432/242272 windows running_bpb=1.177847 + sliding_eval [ 36.3%] 88032/242272 windows running_bpb=1.177417 + sliding_eval [ 37.0%] 89632/242272 windows running_bpb=1.177001 + sliding_eval [ 37.7%] 91232/242272 windows running_bpb=1.176547 + sliding_eval [ 38.3%] 92832/242272 windows running_bpb=1.176781 + sliding_eval [ 39.0%] 94432/242272 windows running_bpb=1.175925 + sliding_eval [ 39.6%] 96032/242272 windows running_bpb=1.175925 + sliding_eval [ 40.3%] 97632/242272 windows running_bpb=1.175845 + sliding_eval [ 41.0%] 99232/242272 windows running_bpb=1.177034 + sliding_eval [ 41.6%] 100832/242272 windows running_bpb=1.177658 + sliding_eval [ 42.3%] 102432/242272 windows running_bpb=1.178020 + sliding_eval [ 42.9%] 104032/242272 windows running_bpb=1.177980 + sliding_eval [ 43.6%] 105632/242272 windows running_bpb=1.178632 + sliding_eval [ 44.3%] 107232/242272 windows running_bpb=1.177687 + sliding_eval [ 44.9%] 108832/242272 windows running_bpb=1.178057 + sliding_eval [ 45.6%] 110432/242272 windows running_bpb=1.178352 + sliding_eval [ 46.2%] 112032/242272 windows running_bpb=1.178338 + sliding_eval [ 46.9%] 113632/242272 windows running_bpb=1.177701 + sliding_eval [ 47.6%] 115232/242272 windows running_bpb=1.177368 + sliding_eval [ 48.2%] 116832/242272 windows running_bpb=1.174378 + sliding_eval [ 48.9%] 118432/242272 windows running_bpb=1.174092 + sliding_eval [ 49.5%] 120032/242272 windows running_bpb=1.173918 + sliding_eval [ 50.2%] 121632/242272 windows running_bpb=1.173851 + sliding_eval [ 50.9%] 123232/242272 windows running_bpb=1.174019 + sliding_eval [ 51.5%] 124832/242272 windows running_bpb=1.174827 + sliding_eval [ 52.2%] 126432/242272 windows running_bpb=1.174602 + sliding_eval [ 52.8%] 128032/242272 windows running_bpb=1.174815 + sliding_eval [ 53.5%] 129632/242272 windows running_bpb=1.175017 + sliding_eval [ 54.2%] 131232/242272 windows running_bpb=1.174614 + sliding_eval [ 54.8%] 132832/242272 windows running_bpb=1.173813 + sliding_eval [ 55.5%] 134432/242272 windows running_bpb=1.173461 + sliding_eval [ 56.1%] 136032/242272 windows running_bpb=1.173268 + sliding_eval [ 56.8%] 137632/242272 windows running_bpb=1.173000 + sliding_eval [ 57.5%] 139232/242272 windows running_bpb=1.172726 + sliding_eval [ 58.1%] 140832/242272 windows running_bpb=1.172391 + sliding_eval [ 58.8%] 142432/242272 windows running_bpb=1.172462 + sliding_eval [ 59.5%] 144032/242272 windows running_bpb=1.172369 + sliding_eval [ 60.1%] 145632/242272 windows running_bpb=1.172261 + sliding_eval [ 60.8%] 147232/242272 windows running_bpb=1.172265 + sliding_eval [ 61.4%] 148832/242272 windows running_bpb=1.172568 + sliding_eval [ 62.1%] 150432/242272 windows running_bpb=1.172452 + sliding_eval [ 62.8%] 152032/242272 windows running_bpb=1.171786 + sliding_eval [ 63.4%] 153632/242272 windows running_bpb=1.172063 + sliding_eval [ 64.1%] 155232/242272 windows running_bpb=1.172400 + sliding_eval [ 64.7%] 156832/242272 windows running_bpb=1.172808 + sliding_eval [ 65.4%] 158432/242272 windows running_bpb=1.172498 + sliding_eval [ 66.1%] 160032/242272 windows running_bpb=1.173118 + sliding_eval [ 66.7%] 161632/242272 windows running_bpb=1.173412 + sliding_eval [ 67.4%] 163232/242272 windows running_bpb=1.172877 + sliding_eval [ 68.0%] 164832/242272 windows running_bpb=1.173500 + sliding_eval [ 68.7%] 166432/242272 windows running_bpb=1.173890 + sliding_eval [ 69.4%] 168032/242272 windows running_bpb=1.175601 + sliding_eval [ 70.0%] 169632/242272 windows running_bpb=1.175732 + sliding_eval [ 70.7%] 171232/242272 windows running_bpb=1.174899 + sliding_eval [ 71.3%] 172832/242272 windows running_bpb=1.175155 + sliding_eval [ 72.0%] 174432/242272 windows running_bpb=1.175396 + sliding_eval [ 72.7%] 176032/242272 windows running_bpb=1.176067 + sliding_eval [ 73.3%] 177632/242272 windows running_bpb=1.176188 + sliding_eval [ 74.0%] 179232/242272 windows running_bpb=1.176370 + sliding_eval [ 74.6%] 180832/242272 windows running_bpb=1.176615 + sliding_eval [ 75.3%] 182432/242272 windows running_bpb=1.176492 + sliding_eval [ 76.0%] 184032/242272 windows running_bpb=1.175827 + sliding_eval [ 76.6%] 185632/242272 windows running_bpb=1.175994 + sliding_eval [ 77.3%] 187232/242272 windows running_bpb=1.176206 + sliding_eval [ 77.9%] 188832/242272 windows running_bpb=1.176310 + sliding_eval [ 78.6%] 190432/242272 windows running_bpb=1.176279 + sliding_eval [ 79.3%] 192032/242272 windows running_bpb=1.175735 + sliding_eval [ 79.9%] 193632/242272 windows running_bpb=1.178802 + sliding_eval [ 80.6%] 195232/242272 windows running_bpb=1.178572 + sliding_eval [ 81.2%] 196832/242272 windows running_bpb=1.178498 + sliding_eval [ 81.9%] 198432/242272 windows running_bpb=1.178449 + sliding_eval [ 82.6%] 200032/242272 windows running_bpb=1.178255 + sliding_eval [ 83.2%] 201632/242272 windows running_bpb=1.178472 + sliding_eval [ 83.9%] 203232/242272 windows running_bpb=1.178806 + sliding_eval [ 84.5%] 204832/242272 windows running_bpb=1.178162 + sliding_eval [ 85.2%] 206432/242272 windows running_bpb=1.177987 + sliding_eval [ 85.9%] 208032/242272 windows running_bpb=1.177657 + sliding_eval [ 86.5%] 209632/242272 windows running_bpb=1.177164 + sliding_eval [ 87.2%] 211232/242272 windows running_bpb=1.176886 + sliding_eval [ 87.8%] 212832/242272 windows running_bpb=1.176803 + sliding_eval [ 88.5%] 214432/242272 windows running_bpb=1.176779 + sliding_eval [ 89.2%] 216032/242272 windows running_bpb=1.176863 + sliding_eval [ 89.8%] 217632/242272 windows running_bpb=1.177357 + sliding_eval [ 90.5%] 219232/242272 windows running_bpb=1.177575 + sliding_eval [ 91.2%] 220832/242272 windows running_bpb=1.177385 + sliding_eval [ 91.8%] 222432/242272 windows running_bpb=1.177571 + sliding_eval [ 92.5%] 224032/242272 windows running_bpb=1.177323 + sliding_eval [ 93.1%] 225632/242272 windows running_bpb=1.178004 + sliding_eval [ 93.8%] 227232/242272 windows running_bpb=1.177806 + sliding_eval [ 94.5%] 228832/242272 windows running_bpb=1.177702 + sliding_eval [ 95.1%] 230432/242272 windows running_bpb=1.177568 + sliding_eval [ 95.8%] 232032/242272 windows running_bpb=1.177564 + sliding_eval [ 96.4%] 233632/242272 windows running_bpb=1.177262 + sliding_eval [ 97.1%] 235232/242272 windows running_bpb=1.177691 + sliding_eval [ 97.8%] 236832/242272 windows running_bpb=1.177699 + sliding_eval [ 98.4%] 238432/242272 windows running_bpb=1.177714 + sliding_eval [ 99.1%] 240032/242272 windows running_bpb=1.177715 + sliding_eval [ 99.7%] 241632/242272 windows running_bpb=1.177949 +final_int8_zlib_roundtrip val_loss:1.9770 val_bpb:1.1709 eval_time:311802ms +final_int8_zlib_roundtrip_exact val_loss:1.97703826 val_bpb:1.17091471