diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/README.md b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/README.md new file mode 100644 index 0000000000..cf1913389c --- /dev/null +++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/README.md @@ -0,0 +1,80 @@ +# GatedDeltaNet (FLA) + Legal Score-First TTT + Brotli-11 Compression + +**val_bpb: 1.01080** (3-seed mean, std 0.00115) | **~15.53 MB** (VALID; PR #1698 is 16.5-16.6 MB → INVALID) | 8×H100 80GB SXM + +## Summary + +This submission is built directly on @arsenis-cmd's PR #1698 (GatedDeltaNet + Legal Score-First TTT, 1.00995 BPB). PR #1698 is currently **invalid** because all 3 of its artifacts exceed the 16,000,000-byte decimal cap (16.47–16.60 MB); it cannot be merged as a record until that is fixed. + +Two changes vs PR #1698: + +1. **Compression: zstandard (level 22) → brotli (quality 11)**. This is the primary fix. Brotli compresses int6-GPTQ byte streams 5-8% better than zstandard on this model (verified: same bits, same weights → 15.54 MB vs 16.60 MB on seed 42). This brings the full-quality `clip_range=31` artifact comfortably under the 16,000,000-byte cap. + +2. **Optional macro-phase SGD TTT**: multi-phase consolidation layered on top of PR #1698's per-chunk SGD TTT, inspired by PR #1700's Multi-Phase Global SGD TTT. Disabled in the scored run (`TTT_MACRO_PHASES=0`) — on this base it was within noise (seed 42: -0.00999 with macro vs -0.01012 without, indistinguishable), but the infrastructure is left in place for future tuning. + +No other changes: architecture (K_KVShare_Wider, 10-layer GDN, 544d, 8H, KV-share stride=2), training (7000-step budget, Muon + Adam, EMA 0.997, SWA, Late QAT), and TTT (score-first SGD lr=0.005, 3 epochs/chunk, freeze first 2 blocks) are identical to PR #1698. 
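For context on the `clip_range` trade-off above: the artifact quantizer is the same per-row symmetric int6 scheme as the late-QAT path in `architectures.py` (`CastedLinear`). A minimal numpy sketch (editor's illustration; the `scale` floor from the training code is simplified to a divide-by-zero guard) shows why PR #1698's proposed drop from `clip_range=31` to 24 costs accuracy:

```python
import numpy as np

def int6_dequant_error(w: np.ndarray, clip_range: int = 31) -> float:
    """Mean abs reconstruction error of per-row symmetric int6 quantization."""
    row_max = np.abs(w).max(axis=1)
    scale = np.maximum(row_max, 1e-8) / clip_range  # guard all-zero rows
    q = np.clip(np.round(w / scale[:, None]), -clip_range, clip_range)
    return float(np.abs(q * scale[:, None] - w).mean())

rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=(256, 256)).astype(np.float32)
err31 = int6_dequant_error(w, 31)  # full-quality grid kept by this PR
err24 = int6_dequant_error(w, 24)  # PR #1698's proposed cap fix: coarser grid
assert err24 > err31 > 0.0
```

The grid step scales as `row_max / clip_range`, so 31 → 24 coarsens every row by ~29%; this submission instead keeps `clip_range=31` and recovers the bytes on the compression side.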
+ +## Results (8xH100 80GB SXM, torch 2.9.1+cu128) + +| Seed | EMA BPB | Pre-TTT BPB | **Post-TTT BPB** | TTT Gain | Artifact | +|------|---------|-------------|------------------|----------|----------| +| 42 | 1.00257 | 1.02189 | **1.01205** | -0.00984 | 15,543,829 B | +| 314 | 1.00033 | 1.01903 | **1.00978** | -0.00925 | 15,527,172 B | +| 999 | 1.00146 | 1.01986 | **1.01056** | -0.00930 | 15,524,066 B | +| **Mean** | **1.00145** | **1.02026** | **1.01080 (std 0.00115)** | **-0.00946** | 15,531,689 B | + +Beats merged SOTA (1.0810, PR #1493) by **-0.07020 BPB / ~-0.04866 nats**, clearing the 0.005-nat (~0.0072 BPB) threshold by a nearly 10x margin. Seed 314 alone (1.00978) is lower than PR #1698's entire 3-seed mean of 1.00995. + +## Why this is the first valid sub-1.02 submission + +PR #1698's three artifacts: +- seed 42: 16,600,916 B (600,916 over cap) +- seed 314: 16,548,775 B (548,775 over cap) +- seed 999: 16,474,250 B (474,250 over cap) + +All violate the 16,000,000-byte decimal cap (Rules: "The cap is decimal 16MB, i.e. 16,000,000 total bytes, not 16 MiB / 16,777,216 bytes"). The author (@arsenis-cmd) acknowledged this in the PR comments and proposed reducing `clip_range` from 31 to 24. That fix works but introduces ~+0.015 BPB quantization penalty because more weights get clipped. + +This PR takes a different fix: keep `clip_range=31` (no extra quantization penalty) and replace zstandard-22 with brotli-11 for artifact compression. Brotli saves ~6% on this byte distribution, bringing all three artifacts well under 16,000,000 bytes with zero quality loss. 
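The validity and margin claims above reduce to a few constants that can be checked directly; a small stdlib sketch (all numbers copied from this README and PR #1698):

```python
import math

CAP_BYTES = 16_000_000                    # decimal 16 MB cap, not 16 MiB
margin_bpb = 1.0810 - 1.01080             # merged SOTA (PR #1493) minus this run
margin_nats = margin_bpb * math.log(2)    # ~0.0487 nats
threshold_bpb = 0.005 / math.log(2)       # 0.005-nat bar ~ 0.00721 BPB

pr1698_bytes = {42: 16_600_916, 314: 16_548_775, 999: 16_474_250}
overage = {s: b - CAP_BYTES for s, b in pr1698_bytes.items()}  # 600_916 / 548_775 / 474_250

assert all(v > 0 for v in overage.values())   # PR #1698: every artifact over cap
assert 15_543_829 < CAP_BYTES                 # this PR's largest artifact (seed 42)
assert margin_bpb > 9 * threshold_bpb         # close to the claimed 10x margin
```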
+ +## Compliance (Issue #1017 Track A) + +- **Condition 1 (Causality)**: Sliding-window eval is strictly causal (same as PR #1698) +- **Condition 2 (Normalized)**: Standard softmax over full vocab +- **Condition 3 (Score-before-update)**: Each 32K-token chunk is fully scored under `torch.inference_mode()` BEFORE any SGD update (same as PR #1698) +- **Condition 4 (Single pass)**: Each token scored exactly once; no rescoring across passes + +The `ttt_epochs=3` multi-epoch SGD on already-scored tokens is the same pattern used in PR #1698, PR #1700, and merged SOTA PR #1493 (`TTT_EPOCHS=3`). The interpretation of Condition 4 vs multi-epoch post-score training is pending organizer clarification — see discussion on PR #1698. + +## Reproduction + +```bash +pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128 +pip install numpy sentencepiece zstandard brotli triton==3.5.1 +pip install flash-linear-attention==0.4.2 fla-core==0.4.2 transformers==5.5.4 tokenizers==0.22.2 safetensors==0.7.0 + +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf \ + python3 data/cached_challenge_fineweb.py --variant sp8192 + +for seed in 42 314 999; do + SEED=$seed \ + ARCH_MODE=K VOCAB_SIZE=8192 \ + DATA_PATH=./data/datasets/fineweb10B_sp8192 \ + TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model \ + MAX_WALLCLOCK_SECONDS=600 \ + INT6_CLIP_RANGE=31 \ + COMPRESSOR=brotli \ + TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 \ + TTT_CHUNK_TOKENS=32768 TTT_FREEZE_BLOCKS=2 TTT_MOMENTUM=0.9 \ + TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \ + TTT_MACRO_PHASES=0 \ + torchrun --standalone --nproc_per_node=8 train_gdn_7k.py +done +``` + +## Credits + +- **@arsenis-cmd** (PR #1698) — full base: GatedDeltaNet integration, K_KVShare_Wider config, all training and score-first TTT infrastructure. This submission changes only compression + adds optional (disabled-in-scored-run) macro-phase hook. +- **@resouer** (PR #1687) — K_KVShare_Wider architecture and FLA integration, consumed by PR #1698. 
+- **Flash Linear Attention** by @sustcsonglin — GatedDeltaNet Triton kernel (`fla-core==0.4.2`). +- **@Christopher-Lee-McClendon** (PR #461) — legal score-first TTT framework. +- **@jorge-asenjo** (PR #1700) / **@dexhunter** (PR #1626) — Multi-Phase Global SGD TTT concept (provides the macro-phase hook design; not used in scored run). diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/architectures.py b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/architectures.py new file mode 100644 index 0000000000..dff180156c --- /dev/null +++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/architectures.py @@ -0,0 +1,709 @@ +"""GDN Hybrid Architecture — modular blocks using FLA native layers. + +Supports 8 model variants (A-H) for the Parameter Golf screening experiments. +Each model is a stack of mixed {GDN, DeltaProduct, RWKV-7, Mamba-2, SWA} blocks +with shared MLP, RMSNorm, and residual connections. + +Key design choices: +- FLA layers handle recurrent attention (GatedDeltaNet, GatedDeltaProduct, RWKV7, Mamba2) +- Sliding Window Attention (SWA) uses flash attention with a causal window mask +- All blocks follow the same pre-norm residual pattern for uniform gradient flow +- Weight sharing for SWA layers in Zamba/Hymba-style models +- Score-first eval: XSA-all only extends attention layers (no future context leakage) +""" +from __future__ import annotations +import math +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +# ─── FLA backend selection ────────────────────────────────────────────────── +# Set FLA_USE_NAIVE=1 to force pure-PyTorch (naive) kernels instead of Triton. 
+# This is needed when: +# - Running on V100 (sm_70) which doesn't support FLA's Triton kernels well +# - Triton cache is corrupted (FileNotFoundError on .json files) +# - Debugging without Triton dependency +# +# On A100 (sm_80+), the Triton kernels are ~3-10x faster and should be used. +_USE_NAIVE = os.environ.get("FLA_USE_NAIVE", "0") == "1" + +if _USE_NAIVE: + # 1. Patch GatedDeltaNet's chunk op + import fla.ops.gated_delta_rule.chunk as _gdr_chunk + import fla.ops.gated_delta_rule.naive as _gdr_naive + + def _patched_chunk_gated_delta_rule( + q, k, v, g, beta, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdr_naive.naive_chunk_gated_delta_rule( + q, k, v, g, beta, + chunk_size=64, scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + ) + + _gdr_chunk.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + import fla.layers.gated_deltanet as _gdn_layer + _gdn_layer.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + + # 2. 
Patch GatedDeltaProduct's chunk op + import fla.ops.gated_delta_product.chunk as _gdp_chunk + import fla.ops.gated_delta_product.naive as _gdp_naive + + def _patched_chunk_gated_delta_product( + q, k, v, g, beta, num_householder=1, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdp_naive.naive_recurrent_gated_delta_product( + q, k, v, g, beta, + scale=scale, cu_seqlens=None, + initial_state=initial_state, + output_final_state=output_final_state, + num_householder=num_householder, + ) + + _gdp_chunk.chunk_gated_delta_product = _patched_chunk_gated_delta_product + import fla.layers.gated_deltaproduct as _gdp_layer + _gdp_layer.chunk_gated_delta_product = _patched_chunk_gated_delta_product + + print("[FLA] Using NAIVE (pure-PyTorch) kernels — set FLA_USE_NAIVE=0 for Triton", flush=True) + +# FLA imports +from fla.layers import GatedDeltaNet, GatedDeltaProduct, Mamba2 +try: + from fla.layers import RWKV7Attention +except Exception: + RWKV7Attention = None # type: ignore + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except ImportError: + def flash_attn_3_func(q, k, v, causal=False, window_size=(-1, -1)): + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + if k2.size(1) != q2.size(1): + rep = q2.size(1) // k2.size(1) + k2 = k2.repeat_interleave(rep, dim=1) + v2 = v2.repeat_interleave(rep, dim=1) + out = torch.nn.functional.scaled_dot_product_attention(q2, k2, v2, is_causal=causal) + return out.transpose(1, 2) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int | None = None, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + """Linear layer that casts input to weight dtype for mixed 
precision. + Supports late QAT (int6 STE) when _qat_enabled is set.""" + _qat_enabled: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(dtype=x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() # STE: forward uses quantized, backward uses full + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + """RoPE embeddings for sliding window attention.""" + def __init__(self, dim: int, base: float = 10000.0, max_len: int = 4096): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.max_len = max_len + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(device)) + cos = freqs.cos().to(dtype) + sin = freqs.sin().to(dtype) + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + """Apply RoPE to the input tensor.""" + d = x.shape[-1] // 2 + x1, x2 = x[..., :d], x[..., d:] + out1 = x1 * cos[:x.shape[-2]] - x2 * sin[:x.shape[-2]] + out2 = x2 * cos[:x.shape[-2]] + x1 * sin[:x.shape[-2]] + return torch.cat([out1, out2], dim=-1) + + +class MLP(nn.Module): + """Feed-forward MLP with configurable activation.""" + def __init__(self, dim: int, mult: float = 3.0, act: str = "relu_sq", leaky_slope: float = 0.5): + super().__init__() + hidden = int(mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + nn.init.zeros_(self.proj.weight) # zero-init output for residual + 
self.act = act + self.leaky_slope = leaky_slope + + def forward(self, x: Tensor) -> Tensor: + x = self.fc(x) + if self.act == "leaky_relu_sq": + x = F.leaky_relu(x, negative_slope=self.leaky_slope) + else: + x = F.relu(x) + return self.proj(x.square()) + + +class SlidingWindowAttention(nn.Module): + """Sliding window causal attention for hybrid models. + + Supports XSA (cross-segment attention) at eval time for extending context + across eval chunks. Window is enforced during training but can be relaxed at eval. + KV can be shared across layers (Zamba-style) by reusing the same module. + """ + def __init__( + self, + dim: int, + num_heads: int = 8, + num_kv_heads: int = 4, + window_size: int = 512, + rope_base: float = 10000.0, + qk_gain_init: float = 1.5, + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.window_size = window_size + + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + nn.init.zeros_(self.proj.weight) + + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.use_xsa = False # enabled at eval time for XSA-all + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """XSA: subtract self-value projection (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + B, T, D = x.shape + q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim) + v = 
self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(B, T, self.num_kv_heads, self.head_dim) + + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(T, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + + if q.is_cuda and q.dtype not in (torch.float16, torch.bfloat16): + q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16) + + # Use window during training, full causal at eval if XSA enabled + y = flash_attn_3_func(q, k, v, causal=True) + + if self.use_xsa: + y = self._xsa_efficient(y, v) + + y = y.reshape(B, T, D) + return self.proj(y) + + +class RecurrentBlock(nn.Module): + """Wraps any FLA recurrent layer (GDN, DeltaProduct, RWKV-7, Mamba-2) with + pre-norm residual connection and MLP.""" + + def __init__( + self, + dim: int, + recurrent_layer: nn.Module, + mlp_mult: float = 3.0, + mlp_act: str = "relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.recurrent = recurrent_layer + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # FLA layers return (output, state) or just output depending on mode + recurrent_out = self.recurrent(self.attn_norm(x_in)) + if isinstance(recurrent_out, tuple): + recurrent_out = recurrent_out[0] + + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * recurrent_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * 
self.mlp(self.mlp_norm(x_out)) + return x_out + + +class AttentionBlock(nn.Module): + """SWA block with pre-norm residual and MLP.""" + + def __init__( + self, + dim: int, + swa: SlidingWindowAttention, + mlp_mult: float = 3.0, + mlp_act: str = "relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.attn = swa + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in), v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out)) + return x_out + + +class SmearGate(nn.Module): + """Weighted average of current and previous token embeddings.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash-based bigram/trigram embedding for additional context.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int, + trigram: bool = False): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self._trigram = trigram + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = 
CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def trigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., :2] = mod + out[..., 2:] = (36313 * t[..., 2:] ^ 27191 * t[..., 1:-1] ^ 51497 * t[..., :-2]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self._trigram: + h = h + self.embed(self.trigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +def _parse_layout(layout_str: str) -> list[tuple[str, int]]: + """Parse a layout string into a sequence of (layer_type, count) pairs. 
+ + Examples: + "gdn_only" -> [("gdn", 11)] (count filled in by caller) + "gdn5_swa_gdn5_swa_shared" -> [("gdn", 5), ("swa", 1), ("gdn", 5), ("swa_shared", 1)] + "gdn3_swa_gdn3_swa_shared_gdn3" -> [("gdn", 3), ("swa", 1), ("gdn", 3), ("swa_shared", 1), ("gdn", 3)] + "mamba_only" -> [("mamba", 11)] + "gdn3_mamba2_swa_gdn3_mamba2" -> [("gdn", 3), ("mamba", 2), ("swa", 1), ("gdn", 3), ("mamba", 2)] + """ + if layout_str == "gdn_only": + return [("gdn", -1)] # -1 = use num_gdn_layers + if layout_str == "mamba_only": + return [("mamba", -1)] # -1 = use num_mamba_layers + if layout_str == "swa_only": + return [("swa", -1)] # -1 = use num_swa_layers + + # Parse custom layouts like "gdn5_swa_gdn5_swa_shared" + parts = layout_str.split("_") + result = [] + i = 0 + while i < len(parts): + part = parts[i] + if part.startswith("gdn") and len(part) > 3: + count = int(part[3:]) + result.append(("gdn", count)) + elif part.startswith("mamba") and len(part) > 5: + count = int(part[5:]) + result.append(("mamba", count)) + elif part == "swa": + # Check if next token is "shared" + if i + 1 < len(parts) and parts[i + 1] == "shared": + result.append(("swa_shared", 1)) + i += 1 + else: + result.append(("swa", 1)) + elif part == "shared": + # Already consumed by swa check above + pass + i += 1 + return result + + +class HybridGDN(nn.Module): + """Hybrid GDN architecture supporting mixed recurrent/attention layers. + + Builds a stack of blocks according to the layer_layout specification: + - "gdn" blocks use GatedDeltaNet (or GatedDeltaProduct, or RWKV-7) + - "mamba" blocks use Mamba-2 + - "swa" blocks use SlidingWindowAttention + - "swa_shared" reuses the same SWA module (Zamba-style weight sharing) + + All models share: token embedding, bigram hash, smear gate, final norm, lm_head. 
+ """ + def __init__(self, config: dict, vocab_size: int = 1024): + super().__init__() + dim = config["model_dim"] + num_heads = config["num_heads"] + mlp_mult = config["mlp_mult"] + self.arch_name = config["arch_name"] + self.model_dim = dim + self.vocab_size = vocab_size + self.logit_softcap = 30.0 + + # Embeddings + self.tok_emb = nn.Embedding(vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.005) + self.bigram = BigramHashEmbedding( + config.get("bigram_vocab_size", 2048), + config.get("bigram_dim", 128), + dim, + trigram=config.get("trigram", False), + ) + self.smear = SmearGate(dim) + + # Meta tokens (Hymba-style, for Model E) + n_meta = config.get("meta_tokens", 0) + if n_meta > 0: + self.meta_tokens = nn.Parameter(torch.randn(1, n_meta, dim) * 0.02) + self.n_meta = n_meta + else: + self.meta_tokens = None + self.n_meta = 0 + + # Build layer stack + layout = _parse_layout(config["layer_layout"]) + self.blocks = nn.ModuleList() + self._block_types = [] # track type for XSA/diagnostics + self._shared_swa = None # shared SWA module for Zamba/Hymba models + + layer_idx = 0 + for layer_type, count in layout: + if count == -1: + # Fill with the specified layer type + if layer_type == "gdn": + count = config["num_gdn_layers"] + elif layer_type == "mamba": + count = config["num_mamba_layers"] + elif layer_type == "swa": + count = config["num_swa_layers"] + + for _ in range(count): + if layer_type == "gdn": + recurrent = self._make_recurrent_layer(config, layer_idx) + block = RecurrentBlock(dim, recurrent, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("gdn") + + elif layer_type == "mamba": + mamba_expand = config.get("mamba_expand", 2) + mamba_head_dim = config.get("gdn_head_dim", 64) + mamba_num_heads = (dim * mamba_expand) // mamba_head_dim + mamba = Mamba2( + num_heads=mamba_num_heads, + head_dim=mamba_head_dim, + hidden_size=dim, + state_size=config.get("mamba_state_size", 64), + expand=mamba_expand, + 
layer_idx=layer_idx, + ) + block = RecurrentBlock(dim, mamba, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("mamba") + + elif layer_type in ("swa", "swa_shared"): + if layer_type == "swa_shared" and self._shared_swa is not None: + swa = self._shared_swa # reuse same SWA module + else: + swa = SlidingWindowAttention( + dim=dim, + num_heads=num_heads, + num_kv_heads=config.get("swa_num_kv_heads", 4), + window_size=config.get("swa_window", 512), + ) + if config.get("swa_shared", False): + self._shared_swa = swa + + # Each SWA position gets its own MLP even if SWA weights are shared + block = AttentionBlock(dim, swa, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("swa" if layer_type == "swa" else "swa_shared") + + layer_idx += 1 + + # KV sharing: share k/v projections between adjacent layers + kv_stride = config.get("kv_sharing_stride", 0) + if kv_stride > 0: + self._apply_kv_sharing(kv_stride) + + self.final_norm = RMSNorm(dim) + # Tied embeddings (standard for parameter golf) + self.lm_head = None # use tok_emb.weight + self._init_weights() + + def _make_recurrent_layer(self, config: dict, layer_idx: int) -> nn.Module: + """Create the appropriate recurrent layer based on config.""" + dim = config["model_dim"] + num_heads = config["num_heads"] + + if config.get("use_rwkv7", False): + total_layers = config.get("num_gdn_layers", 11) + return RWKV7Attention( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + layer_idx=layer_idx, + num_hidden_layers=total_layers, + mode="chunk", + ) + elif config.get("use_deltaproduct", False): + return GatedDeltaProduct( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + num_householder=config.get("dp_num_householder", 2), + allow_neg_eigval=config.get("dp_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + 
layer_idx=layer_idx, + mode="chunk", + ) + else: + # Default: GatedDeltaNet + return GatedDeltaNet( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + allow_neg_eigval=config.get("gdn_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + + def _apply_kv_sharing(self, stride: int) -> None: + """Share KV projection modules between adjacent layer groups. + + For GDN layers: shares k_proj, v_proj, k_conv1d, v_conv1d. + For SWA layers: shares c_k, c_v. + The first layer in each group is the anchor; subsequent layers in the + group become followers that reference the anchor's modules. + """ + # Collect indices by block type + gdn_indices = [i for i, t in enumerate(self._block_types) if t == "gdn"] + swa_indices = [i for i, t in enumerate(self._block_types) + if t in ("swa", "swa_shared")] + + # Share GDN KV projections within each stride-group + for group_start in range(0, len(gdn_indices), stride): + anchor_idx = gdn_indices[group_start] + anchor = self.blocks[anchor_idx].recurrent + for j in range(1, stride): + if group_start + j >= len(gdn_indices): + break + follower_idx = gdn_indices[group_start + j] + follower = self.blocks[follower_idx].recurrent + follower.k_proj = anchor.k_proj + follower.v_proj = anchor.v_proj + follower.k_conv1d = anchor.k_conv1d + follower.v_conv1d = anchor.v_conv1d + + # Share SWA KV projections within each stride-group + for group_start in range(0, len(swa_indices), stride): + anchor_idx = swa_indices[group_start] + anchor = self.blocks[anchor_idx].attn + for j in range(1, stride): + if group_start + j >= len(swa_indices): + break + follower_idx = swa_indices[group_start + j] + follower = self.blocks[follower_idx].attn + follower.c_k = anchor.c_k + follower.c_v = anchor.c_v + + def _init_weights(self) -> None: + """Weight initialization. 
+ + Each sub-module handles its own init (MLP zeros proj, SWA zeros proj, + FLA layers do own init). We just do the residual scaling for output + projections on our own CastedLinear layers. + """ + total_layers = len(self.blocks) + for name, p in self.named_parameters(): + # Skip FLA-internal parameters + if ".recurrent." in name: + continue + # Scale down output projections for residual stream + if p.ndim == 2 and "proj" in name and "bigram" not in name: + with torch.no_grad(): + p.mul_(1.0 / math.sqrt(2 * total_layers)) + + def set_xsa(self, enable: bool = True) -> None: + """Enable/disable XSA on all attention blocks.""" + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + block.attn.use_xsa = enable + + def _compute_logits(self, x: Tensor) -> Tensor: + """Compute logits with tied embeddings and softcap.""" + logits = F.linear(x, self.tok_emb.weight) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + """Forward pass returning cross-entropy loss.""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + + # Prepend meta tokens if Hymba-style + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + x = block(x, x0) + else: + x = block(x, x0) + + # Remove meta tokens before computing logits + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + + x = self.final_norm(x) + logits = self._compute_logits(x.reshape(-1, x.size(-1))) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning logits (for 
evaluation).""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + x = block(x, x0) + else: + x = block(x, x0) + + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + + x = self.final_norm(x) + return self._compute_logits(x) + + def get_diagnostics(self) -> dict: + """Collect per-layer weight statistics for checkpoint diagnostics.""" + diag = {} + for i, (block, btype) in enumerate(zip(self.blocks, self._block_types)): + prefix = f"layer_{i}_{btype}" + for name, param in block.named_parameters(): + if param.ndim >= 2: + w = param.data.float() + diag[f"{prefix}/{name}/std"] = w.std().item() + diag[f"{prefix}/{name}/kurtosis"] = (((w - w.mean()) / (w.std() + 1e-8)) ** 4).mean().item() - 3.0 + return diag + + def count_params(self) -> dict: + """Count parameters by category.""" + cats = {"embedding": 0, "recurrent": 0, "attention": 0, "mlp": 0, "other": 0} + for name, p in self.named_parameters(): + n = p.numel() + if "tok_emb" in name or "bigram" in name: + cats["embedding"] += n + elif any(k in name for k in ["recurrent", "gdn", "mamba", "rwkv", "delta"]): + cats["recurrent"] += n + elif "attn" in name or "c_q" in name or "c_k" in name or "c_v" in name: + cats["attention"] += n + elif "mlp" in name or "fc" in name: + cats["mlp"] += n + else: + cats["other"] += n + cats["total"] = sum(cats.values()) + return cats diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/configs.py b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/configs.py new file mode 100644 index 0000000000..5bbdac3bd4 --- /dev/null +++ 
b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/configs.py @@ -0,0 +1,316 @@ +"""Model architecture configurations for GDN Hybrid experiments. + +Each config returns a dict consumed by HybridGDN.__init__. +All models are sized to fit ~16MB at int6+zstd-22. + +Models A-H: baseline architecture sweeps. +Models I-K: KV sharing experiments (kv_sharing_stride=2). +""" +from __future__ import annotations + + +def model_a_pure_gdn() -> dict: + """Model A: Pure GDN (Baseline) — 10 layers Gated DeltaNet.""" + return dict( + arch_name="A_PureGDN", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_b_deltaproduct() -> dict: + """Model B: Gated DeltaProduct n_h=2 — rank-2 state transitions.""" + return dict( + arch_name="B_DeltaProduct", + num_gdn_layers=10, # 10 layers to fit param budget + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, # slightly narrower to fit 16MB + num_heads=8, + mlp_mult=3.0, + use_deltaproduct=True, + dp_num_householder=2, + dp_allow_neg_eigval=False, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + ) + + +def model_b2_deltaproduct_neg() -> dict: + """Model B2: DeltaProduct + negative eigenvalues.""" + cfg = model_b_deltaproduct() + cfg["arch_name"] = "B2_DeltaProduct_NegEig" + cfg["dp_allow_neg_eigval"] = True + return cfg + + +def model_c_gdn_neg() -> dict: + """Model C: GDN with negative eigenvalues — richer state dynamics. + + (Originally RWKV-7, replaced because RWKV7 requires Triton kernels with + no pure-PyTorch fallback available.) 
+ """ + return dict( + arch_name="C_GDN_NegEig", + num_gdn_layers=11, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=True, # Key difference: negative eigenvalues + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + ) + + +def model_d_gdn_1swa() -> dict: + """Model D: GDN + 1 Shared SWA (Zamba-style).""" + return dict( + arch_name="D_GDN_1SWA", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=1, + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + # Layout: [GDN×5] → [SWA] → [GDN×5] → [SWA_shared] + layer_layout="gdn5_swa_gdn5_swa_shared", + ) + + +def model_e_gdn_2swa() -> dict: + """Model E: GDN + 2 Shared SWA (Hymba-inspired) with meta-tokens.""" + return dict( + arch_name="E_GDN_2SWA_Hymba", + num_gdn_layers=9, + num_mamba_layers=0, + num_swa_layers=1, # 1 unique, shared at 2 positions + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=4, # Hymba-style prepended meta-tokens + # Layout: [GDN×3] → [SWA] → [GDN×3] → [SWA_shared] → [GDN×3] + layer_layout="gdn3_swa_gdn3_swa_shared_gdn3", + ) + + +def model_f_mamba2() -> dict: + """Model F: Mamba-2 Pure (Mamba-3 proxy with RoPE on B/C).""" + return dict( + arch_name="F_Mamba2", + num_gdn_layers=0, + num_mamba_layers=11, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + mamba_state_size=64, + mamba_expand=2, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="mamba_only", + ) + + 
+def model_g_hybrid() -> dict: + """Model G: GDN + Mamba-2 + SWA triple hybrid.""" + return dict( + arch_name="G_GDN_Mamba_SWA", + num_gdn_layers=6, + num_mamba_layers=4, + num_swa_layers=1, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + mamba_state_size=64, + mamba_expand=2, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + # Layout: [GDN×3] → [Mamba×2] → [SWA] → [GDN×3] → [Mamba×2] + layer_layout="gdn3_mamba2_swa_gdn3_mamba2", + ) + + +def model_h_pure_swa() -> dict: + """Model H: Pure Sliding Window Attention (standard softmax) — control baseline. + + All 10 layers use causal sliding-window softmax attention (no GDN). + Same MLP, embedding, and normalization as Model A for fair comparison. + """ + return dict( + arch_name="H_PureSWA", + num_gdn_layers=0, + num_mamba_layers=0, + num_swa_layers=10, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="swa_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_i_kv_share() -> dict: + """Model I: GDN + KV Share — same as A but with kv_sharing_stride=2.""" + return dict( + arch_name="I_KVShare", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +def model_j_kv_share_deeper() -> dict: + """Model J: GDN + KV Share + Deeper — 12L dim=480, near iso-parameter to A.""" + return dict( + arch_name="J_KVShare_Deeper", + num_gdn_layers=12, + 
num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +def model_k_kv_share_wider() -> dict: + """Model K: GDN + KV Share + Wider — 10L dim=544, iso-parameter to A.""" + return dict( + arch_name="K_KVShare_Wider", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=544, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +ALL_CONFIGS = { + "A": model_a_pure_gdn, + "B": model_b_deltaproduct, + "B2": model_b2_deltaproduct_neg, + "C": model_c_gdn_neg, + "D": model_d_gdn_1swa, + "E": model_e_gdn_2swa, + "F": model_f_mamba2, + "G": model_g_hybrid, + "H": model_h_pure_swa, + "I": model_i_kv_share, + "J": model_j_kv_share_deeper, + "K": model_k_kv_share_wider, +} + + +def get_config(model_id: str) -> dict: + """Get config by model ID (A, B, B2, C, D, E, F, G, H, I, J, K).""" + if model_id not in ALL_CONFIGS: + raise ValueError(f"Unknown model ID '{model_id}'. 
Choose from {list(ALL_CONFIGS.keys())}") + return ALL_CONFIGS[model_id]() diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/requirements.txt b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/requirements.txt new file mode 100644 index 0000000000..4e433194cb --- /dev/null +++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/requirements.txt @@ -0,0 +1,11 @@ +numpy +torch==2.9.1 +sentencepiece +zstandard +brotli +flash-linear-attention==0.4.2 +fla-core==0.4.2 +triton +transformers==5.5.4 +tokenizers==0.22.2 +safetensors==0.7.0 diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/submission.json b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/submission.json new file mode 100644 index 0000000000..4b04f6d8ad --- /dev/null +++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/submission.json @@ -0,0 +1,37 @@ +{ + "author": "yahya010", + "github_id": "yahya010", + "name": "GatedDeltaNet (FLA) + Legal Score-First TTT + Brotli-11 Compression", + "blurb": "First VALID sub-1.02 BPB submission. PR #1698's GatedDeltaNet + Legal Score-First TTT stack with zstandard-22 replaced by brotli-11 compression. Brotli saves ~6% over zstd on the int6-GPTQ byte stream, bringing all 3 artifacts comfortably under the 16,000,000-byte cap while keeping clip_range=31 (no extra quant penalty). PR #1698 itself is invalid (artifacts 16.47-16.60 MB, over the 16,000,000-byte decimal cap). 
Macro-phase SGD TTT hook added (from PR #1700's Multi-Phase design) but disabled in the scored run (ttt_macro_phases=0) because on seed 42 it was indistinguishable from vanilla per-chunk SGD (-0.00999 vs -0.01012).", + "date": "2026-04-19", + "track": "10min_16mb", + "val_bpb": 1.01080, + "val_bpb_std": 0.00115, + "seeds": [42, 314, 999], + "seed_results": { + "42": {"val_bpb": 1.01205, "artifact_bytes": 15543829}, + "314": {"val_bpb": 1.00978, "artifact_bytes": 15527172}, + "999": {"val_bpb": 1.01056, "artifact_bytes": 15524066} + }, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "technique_summary": "SP8192 + GatedDeltaNet (FLA linear attention, K_KVShare_Wider from PR #1687) + Legal Score-First SGD TTT (3 epochs per 32K-token chunk, freeze first 2 blocks) + int6 GPTQ (clip_range=31) + int8 embeddings + brotli-11 compression. Optional macro-phase SGD hook from PR #1700 (disabled in scored run).", + "compliance": { + "train_under_600s": true, + "artifact_under_16mb": true, + "eval_under_600s": true, + "no_slot": true, + "no_pre_quant_ttt": true, + "no_etlb": true, + "no_ngram_cache": true, + "score_first_ttt": true, + "three_seeds": true + }, + "attribution": { + "full_base_stack": "@arsenis-cmd (PR #1698) — GatedDeltaNet + Legal Score-First TTT on K_KVShare_Wider", + "architecture_and_config": "@resouer (PR #1687, K_KVShare_Wider config + FLA integration)", + "fla_kernel": "@sustcsonglin (flash-linear-attention)", + "score_first_ttt": "@Christopher-Lee-McClendon (PR #461)", + "macro_phase_concept": "@jorge-asenjo (PR #1700), @dexhunter (PR #1626)" + } +} diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_gdn_7k.py b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_gdn_7k.py new file mode 100644 index 0000000000..c68143687c --- /dev/null +++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_gdn_7k.py @@ -0,0 +1,1443 @@ 
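The `artifact_under_16mb` compliance claim above, and the README's comparison against PR #1698, reduce to a single byte comparison per seed against the decimal cap. A minimal sketch, with byte counts taken from the tables in this record (the `artifacts_valid` helper is hypothetical, not part of any harness):

```python
# The track cap is decimal 16 MB: 16,000,000 bytes, not 16 MiB (16,777,216).
CAP_BYTES = 16_000_000

def artifacts_valid(seed_results: dict) -> bool:
    """True only if every seed's artifact is strictly under the decimal cap."""
    return all(r["artifact_bytes"] < CAP_BYTES for r in seed_results.values())

# This submission's three artifacts (brotli-11, clip_range=31):
this_pr = {
    "42":  {"artifact_bytes": 15_543_829},
    "314": {"artifact_bytes": 15_527_172},
    "999": {"artifact_bytes": 15_524_066},
}
# PR #1698's artifacts (zstd-22, same weights) — all over the cap:
pr_1698 = {
    "42":  {"artifact_bytes": 16_600_916},
    "314": {"artifact_bytes": 16_548_775},
    "999": {"artifact_bytes": 16_474_250},
}
```

Running `artifacts_valid` on the two dicts reproduces the validity split claimed in the README: all three brotli-11 artifacts pass, all three zstd-22 artifacts fail.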
+#!/usr/bin/env python3 +"""GDN Hybrid Full Training Script — 7000 steps with all production features. + +Features beyond Phase 1 screening (train_gdn.py): + - EMA weight averaging (decay 0.997) + - SWA (stochastic weight averaging) during late warmdown + - Late QAT (int6 STE in CastedLinear forward during warmdown) + - Mixed int6/int8 quantization with percentile search + - Artifact compression (brotli-11 in this submission; zstd-22 in PR #1698) + - Coprime shard ordering via SHARD_ORDER_FILE + - XSA-all sliding window eval on quantized artifact + - Roundtrip validation (load quantized, eval, report exact BPB) + +Environment variables (key additions vs Phase 1): + EMA_DECAY: EMA decay rate (default 0.997) + SWA_ENABLED: 1|0 (default 1) + SWA_EVERY: SWA collection interval (default 50) + LATE_QAT_THRESHOLD: LR scale below which QAT activates (default 0.15) + SHARD_ORDER_FILE: path to file with one shard path per line (coprime ordering) + MUON_MOMENTUM_WARMUP_START: starting momentum for warmup (default 0.85) + MUON_MOMENTUM_WARMUP_STEPS: steps to ramp momentum (default 500) +""" +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +import zstandard +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from architectures import HybridGDN, CastedLinear +from configs import get_config + + +# ─── Hyperparameters ────────────────────────────────────────────────────────── + +class Hyperparameters: + arch_mode = os.environ.get("ARCH_MODE", "A") + data_path = os.environ.get("DATA_PATH", "../data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") +
tokenizer_path = os.environ.get("TOKENIZER_PATH", "../data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + + # Training length + iterations = int(os.environ.get("ITERATIONS", 7000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2100)) # 30% of 7k + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 14100.0)) # 3h55m safety + + # Validation + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + save_every = int(os.environ.get("SAVE_EVERY", 1000)) + + # Optimizer + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + + # Eval + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + xsa_eval = bool(int(os.environ.get("XSA_EVAL", "0"))) # during training + eval_compile_enabled = 
bool(int(os.environ.get("EVAL_COMPILE_ENABLED", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Checkpoint + ckpt_dir = os.environ.get("CKPT_DIR", "checkpoints") + + # Compile + compile_enabled = bool(int(os.environ.get("COMPILE_ENABLED", "1"))) + + # Resume from checkpoint + resume_ckpt = os.environ.get("RESUME_CKPT", "") + + # EMA / SWA + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Late QAT + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # TTT (Test-Time Training) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.005)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + # Multi-phase macro SGD on top of per-chunk TTT (novel extension) + ttt_macro_phases = int(os.environ.get("TTT_MACRO_PHASES", 0)) # 0 = disabled, 4 = typical + ttt_macro_epochs = int(os.environ.get("TTT_MACRO_EPOCHS", 1)) + ttt_macro_lr_mult = float(os.environ.get("TTT_MACRO_LR_MULT", 0.5)) # macro LR = ttt_lr * this + + # Chained job support + auto_save_seconds = float(os.environ.get("AUTO_SAVE_SECONDS", "0")) + total_iterations = int(os.environ.get("TOTAL_ITERATIONS", "0")) # 0 = same as iterations + + +# ─── Data Loading ───────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + header = np.fromfile(file, dtype=np.uint32, count=256) + assert header[0] == 20240520, f"Bad magic: {header[0]}" + assert header[1] in (1, 7), f"Bad version: {header[1]}" + ntok = int(header[2]) + return 
torch.from_numpy(np.fromfile(file, dtype=np.uint16, offset=256 * 4)[:ntok].astype(np.int64)) + + +class TokenStream: + """Reads shards sequentially, supports coprime ordering via SHARD_ORDER_FILE.""" + def __init__(self, pattern: str): + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if shard_order_file and os.path.exists(shard_order_file): + with open(shard_order_file) as f: + self.files = [Path(line.strip()) for line in f if line.strip()] + else: + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + assert self.files, f"No files matching {pattern}" + self.idx = 0 + self.buf = load_data_shard(self.files[self.idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.idx = (self.idx + 1) % len(self.files) + self.buf = load_data_shard(self.files[self.idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + parts = [] + remaining = n + while remaining > 0: + avail = self.buf.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + take_n = min(avail, remaining) + parts.append(self.buf[self.pos:self.pos + take_n]) + self.pos += take_n + remaining -= take_n + return torch.cat(parts) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.stream = TokenStream(pattern) + self.rank = rank + self.world_size = world_size + self.device = device + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + tokens_per_rank = global_tokens // self.world_size + seqs_per_rank = tokens_per_rank // seq_len + total_seqs = seqs_per_rank * self.world_size + total_needed = total_seqs * seq_len + 1 + all_tokens = self.stream.take(total_needed) + start = self.rank * seqs_per_rank * seq_len + chunk = all_tokens[start:start + seqs_per_rank * seq_len + 1] + x = chunk[:-1].reshape(seqs_per_rank, seq_len) + y = chunk[1:].reshape(seqs_per_rank, seq_len) + return x.to(self.device), y.to(self.device) + + +def 
load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = sorted(glob.glob(pattern)) + parts = [load_data_shard(Path(f)) for f in files] + combined = torch.cat(parts) + return combined[:((combined.numel() - 1) // seq_len) * seq_len + 1] + + +def build_sentencepiece_luts(sp, vocab_size, device): + base_bytes = torch.zeros(vocab_size, dtype=torch.float32, device=device) + has_space = torch.zeros(vocab_size, dtype=torch.bool, device=device) + is_boundary = torch.zeros(vocab_size, dtype=torch.bool, device=device) + for i in range(vocab_size): + piece = sp.id_to_piece(i) + raw = piece.encode("utf-8") + base_bytes[i] = len(raw) + if piece.startswith("\u2581"): + has_space[i] = True + base_bytes[i] = len(piece[1:].encode("utf-8")) + 1 + if sp.is_control(i) or sp.is_unknown(i): + is_boundary[i] = True + return base_bytes, has_space, is_boundary + + +def generate_coprime_shard_order(shard_files: list, seed: int = 42) -> list: + """Generate a coprime-stepping shard order for better data mixing. + + Instead of sequential 0,1,2,...,79, uses stride = coprime(N) to visit + all shards in a maximally-spread order. This ensures each training epoch + sees data from diverse shards rather than correlated sequential ones. 
+ """ + n = len(shard_files) + if n <= 1: + return shard_files + + # Find a coprime stride that's roughly n/phi (golden ratio) + target = max(1, int(n / 1.618)) + stride = target + while math.gcd(stride, n) != 1: + stride += 1 + + rng = random.Random(seed) + start = rng.randint(0, n - 1) + order = [] + pos = start + for _ in range(n): + order.append(shard_files[pos]) + pos = (pos + stride) % n + return order + + +# ─── Muon Optimizer ────────────────────────────────────────────────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if transposed: + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay) + super().__init__(params, defaults) + + def step(self, closure=None): + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + for p in group["params"]: + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g + momentum * buf + else: + g = buf + if g.ndim == 2 and min(g.shape) >= 2: + g = zeropower_via_newtonschulz5(g, steps=group["backend_steps"]) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.data.add_(g, alpha=-lr) + + +# ─── Evaluation ────────────────────────────────────────────────────────────── + +def eval_val_sliding( + model: nn.Module, + val_tokens: Tensor, + 
base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + rank: int, + world_size: int, + device: torch.device, + seq_len: int = 1024, + stride: int = 64, + batch_seqs: int = 128, + xsa_eval: bool = False, + compile_enabled: bool = True, +) -> tuple[float, float]: + """Score-first sliding window evaluation.""" + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + base_model = model.module if hasattr(model, 'module') else model + if xsa_eval and hasattr(base_model, 'set_xsa'): + base_model.set_xsa(True) + + forward_fn = base_model.forward_logits + compiled_logits = forward_fn + if compile_enabled: + try: + compiled_logits = torch.compile(forward_fn, dynamic=False) + except Exception: + compiled_logits = forward_fn + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + 
y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + + if xsa_eval and hasattr(base_model, 'set_xsa'): + base_model.set_xsa(False) + + model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ─── TTT (Test-Time Training) ──────────────────────────────────────────────── + +def eval_val_ttt_gdn( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT for GDN: score each chunk with sliding windows, + then SGD-train on already-scored tokens. 
Every token scored BEFORE any update.""" + seq_len = args.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_gdn:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks (early layers learn general features, keep stable) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_gdn:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + # Macro-phase structure: every (num_chunks / macro_phases) chunks, do an extra + # SGD pass on all chunks accumulated in this macro-phase. Novel extension of + # PR #1700's multi-phase Global SGD applied to PR #1698's chunk-based TTT. + macro_phases = args.ttt_macro_phases + if macro_phases > 0 and num_chunks > macro_phases: + macro_boundaries = [ + max(1, (num_chunks * (mp + 1)) // macro_phases) - 1 + for mp in range(macro_phases) + ] + else: + macro_boundaries = [] + macro_start_ci = 0 + + def _run_macro_phase(start_ci, end_ci): + """Extra SGD epochs on tokens from chunks [start_ci, end_ci] pooled.""" + if args.ttt_macro_epochs <= 0: + return + tok_start = start_ci * ttt_chunk + tok_end = min((end_ci + 1) * ttt_chunk, total_tokens) + span = tok_end - tok_start + if span < seq_len: + return + chunk_seqs = span // seq_len + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + if my_chunk_seqs <= 0: + return + macro_lr = args.ttt_lr * args.ttt_macro_lr_mult + for pg in optimizer.param_groups: + pg['lr'] = macro_lr + base_model.train() + for _ep in range(args.ttt_macro_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = tok_start + actual_bs * seq_len + end_tok = tok_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = 
local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + base_model.eval() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + # Skip training on last chunk (no future windows benefit from it) + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + # End of a macro-phase? If so, run extra consolidation SGD on its chunks. 
+ if ci in macro_boundaries and ci < num_chunks - 1: + if rank == 0: + log0(f" macro_phase: consolidating chunks [{macro_start_ci}..{ci}] t={time.perf_counter() - t0:.1f}s") + _run_macro_phase(macro_start_ci, ci) + macro_start_ci = ci + 1 + # Reset per-chunk optimizer's LR state (SGD doesn't have much but be safe) + for pg in optimizer.param_groups: + pg['lr'] = args.ttt_lr + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore all parameters to requires_grad=True + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_gdn:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ─── Quantization ──────────────────────────────────────────────────────────── + +# Control tensor patterns — kept at full precision during quantization +CONTROL_PATTERNS = ( + "resid_mix", "q_gain", "smear", "skip_weight", "attn_scale", "mlp_scale", +) + + +def generate_autoregressive_calib(model, device, num_seqs=64, seq_len=2048, + vocab_size=1024, temperature=0.8, batch_size=8, seed=42): + """Generate sequences autoregressively from the model for GPTQ calibration. 
+ No external data accessed — fully self-contained.""" + model.eval() + rng = torch.Generator(device=device) + rng.manual_seed(seed) + all_tokens = [] + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for batch_start in range(0, num_seqs, batch_size): + bs = min(batch_size, num_seqs - batch_start) + tokens = torch.randint(0, vocab_size, (bs, 1), device=device, generator=rng) + for pos in range(seq_len - 1): + logits = model.forward_logits(tokens) + next_logit = logits[:, -1, :] + probs = torch.softmax(next_logit / temperature, dim=-1) + next_tok = torch.multinomial(probs, 1, generator=rng) + tokens = torch.cat([tokens, next_tok], dim=1) + for i in range(bs): + all_tokens.append(tokens[i:i+1]) + return all_tokens + + +def collect_hessians_from_tokens(hessian_model, token_seqs, device): + """Collect H = X^T X from pre-generated token sequences.""" + hessians = {} + hooks = [] + for name, module in hessian_model.named_modules(): + if isinstance(module, CastedLinear): + param_name = name + ".weight" + cols = module.weight.shape[1] + hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') + def make_hook(pname): + def hook_fn(module, input, output): + x = input[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[pname] += (x.T @ x).cpu() + return hook_fn + h = module.register_forward_hook(make_hook(param_name)) + hooks.append(h) + hessian_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for seq in token_seqs: + x = seq[:, :-1].to(device) + y = seq[:, 1:].to(device) + hessian_model(x, y) + for h in hooks: + h.remove() + num_batches = len(token_seqs) + for name in hessians: + hessians[name] /= num_batches + return hessians + + +def quantize_int6_gptq(weight, hessian=None, clip_range=int(os.environ.get("INT6_CLIP_RANGE", 24)), block_size=128): + """Full GPTQ: Hessian-aware int6 quantization with Cholesky error compensation. 
+ If hessian is None, falls back to percentile search.""" + t32 = weight.float() + if t32.ndim != 2 or hessian is None: + return quantize_int6_per_row(t32) + rows, cols = t32.shape + H = hessian.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * torch.mean(torch.diag(H)) + H[torch.arange(cols), torch.arange(cols)] += damp + perm = torch.argsort(torch.diag(H), descending=True) + inv_perm = torch.argsort(perm) + W = t32[:, perm].clone() + W[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.linalg.cholesky(H) + Hinv = torch.cholesky_inverse(Hinv) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + best_q = None; best_scale = None; best_err = float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + sf = s.float() + Q = torch.zeros_like(W, dtype=torch.int8) + W_work = W.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + count = i2 - i1 + W1 = W_work[:, i1:i2].clone() + Q1 = torch.zeros(rows, count, dtype=torch.int8) + Err1 = torch.zeros(rows, count) + Hinv1 = Hinv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8) + Q1[:, i] = q + err = (w - q.float() * sf) / d + W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0) + Err1[:, i] = err + Q[:, i1:i2] = Q1 + if i2 < cols: + W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + recon = Q.float() * sf[:, None] + mse = (W - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + best_q = best_q[:, inv_perm] + return best_q, best_scale + + +def quantize_int6_per_row(t: Tensor, clip_range: int = int(os.environ.get("INT6_CLIP_RANGE", 24))) -> tuple[Tensor, Tensor]: + """Int6 quantization with percentile search for optimal clipping.""" + 
t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + # 1D: simple per-tensor quantization + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + + +def quantize_int8_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + """Int8 quantization with percentile clipping.""" + t32 = t.float() + clip_q = 0.9999984 + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), clip_q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0).to(torch.float16) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -127, 127).to(torch.int8) + return q, scale + clip_abs = float(torch.quantile(t32.abs().flatten(), clip_q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale.float()), -127, 127).to(torch.int8) + return q, scale + + +def mixed_quantize(state_dict: dict[str, Tensor], hessians: dict[str, Tensor] | None = None) -> tuple[dict[str, Tensor], dict[str, object]]: + """Mixed int6 (large weights) / int8 (small weights) quantization.""" + result: dict[str, Tensor] = {} + meta: 
dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Control tensors: keep fp16 passthrough + if any(p in name for p in CONTROL_PATTERNS): + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + # Non-float: passthrough + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + # Small tensors: fp16 passthrough + if t.numel() <= 65536: + result[name] = t.to(torch.float16) + meta[name] = "passthrough" + continue + # Large 2D weights: int6 (6-bit quantization for better compression) + if t.ndim == 2 and t.numel() > 65536: + H = hessians.get(name) if hessians else None + q, s = quantize_int6_gptq(t, hessian=H) if H is not None else quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + # Other float tensors: int8 + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + + +def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Dequantize mixed int6/int8 back to float.""" + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info == "passthrough": + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ─── Checkpoint Saving ─────────────────────────────────────────────────────── + +def save_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed): + base 
= model.module if hasattr(model, 'module') else model + ckpt = { + "step": step, "val_bpb": val_bpb, + "arch_name": arch_name, "seed": seed, + "model_state_dict": base.state_dict(), + } + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"{arch_name}_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def save_full_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + qat_enabled, rng_states=None, stream_state=None): + """Save complete training state for chained job resume.""" + base = model.module if hasattr(model, 'module') else model + ckpt = { + "step": step, "val_bpb": val_bpb, + "arch_name": arch_name, "seed": seed, + "model_state_dict": {k: v.cpu() for k, v in base.state_dict().items()}, + "muon_opt_state": muon_opt.state_dict(), + "adam_opt_state": adam_opt.state_dict(), + "ema_state": {k: v.cpu() for k, v in ema_state.items()}, + "swa_state": {k: v.cpu() for k, v in swa_state.items()} if swa_state is not None else None, + "swa_count": swa_count, + "qat_enabled": qat_enabled, + } + if rng_states is not None: + ckpt["rng_states"] = rng_states + if stream_state is not None: + ckpt["stream_state"] = stream_state + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"full_ckpt_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def _find_latest_full_ckpt(ckpt_dir): + """Find the latest full_ckpt_step*.pt file in ckpt_dir by step number.""" + pattern = os.path.join(ckpt_dir, "full_ckpt_step*_seed*.pt") + files = glob.glob(pattern) + if not files: + return None + import re + step_re = re.compile(r"full_ckpt_step(\d+)_seed") + best_step, best_path = -1, None + for f in files: + m = step_re.search(os.path.basename(f)) + if m: + s = int(m.group(1)) + if s > best_step: + best_step, best_path = s, f + return best_path + + +# ─── Main Training Loop ───────────────────────────────────────────────────── + +def main(): + global 
zeropower_via_newtonschulz5 + args = Hyperparameters() + config = get_config(args.arch_mode) + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = max(1, 8 // world_size) + master_process = rank == 0 + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if args.compile_enabled: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # Logging + os.makedirs("logs", exist_ok=True) + os.makedirs(args.ckpt_dir, exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master_process else None + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg, flush=True) + if logfile: + with open(logfile, "a") as f: + print(msg, file=f) + + # Seeds + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + log0(f"=== GDN Hybrid 7k Full Training ===") + log0(f"Arch: {config['arch_name']} (ARCH_MODE={args.arch_mode})") + log0(f"Seed: {args.seed}, Steps: {args.iterations}, Warmdown: {args.warmdown_iters}") + log0(f"World size: {world_size}, Grad accum: {grad_accum_steps}") + log0(f"EMA decay: {args.ema_decay}, SWA: {args.swa_enabled} (every {args.swa_every})") + log0(f"Late QAT threshold: {args.late_qat_threshold}") + log0(f"Eval compile enabled: {args.eval_compile_enabled}") + + # Tokenizer + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + assert int(sp.vocab_size()) == args.vocab_size + + # Validation data + val_tokens = 
load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"Validation tokens: {val_tokens.numel()-1:,}") + + # Build model + _t0 = time.time() + model = HybridGDN(config, args.vocab_size) + model = model.to(device).bfloat16() + log0(f"Model built in {time.time()-_t0:.1f}s") + + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, p in model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + + param_counts = model.count_params() + log0(f"Parameters: {param_counts}") + log0(f"Total params: {param_counts['total']:,}") + + # Resume from checkpoint if specified + start_step = 0 + resume_state = None # holds full checkpoint data for deferred restore + resume_ckpt_path = args.resume_ckpt + if resume_ckpt_path == "auto": + resume_ckpt_path = _find_latest_full_ckpt(args.ckpt_dir) or "" + if resume_ckpt_path: + log0(f"Auto-detected resume checkpoint: {resume_ckpt_path}") + else: + log0("Auto-resume: no full checkpoint found, starting fresh") + if resume_ckpt_path and os.path.exists(resume_ckpt_path): + log0(f"Resuming from checkpoint: {resume_ckpt_path}") + ckpt = torch.load(resume_ckpt_path, map_location="cpu", weights_only=False) + base_sd = ckpt["model_state_dict"] + model.load_state_dict({k: v.to(device) for k, v in base_sd.items()}, strict=True) + start_step = ckpt.get("step", 0) + log0(f"Resumed model at step {start_step}, val_bpb={ckpt.get('val_bpb', 'N/A')}") + # Keep full checkpoint for deferred optimizer/EMA/SWA restore + if "muon_opt_state" in ckpt: + resume_state = ckpt + log0(" Full checkpoint detected — will restore optimizers, EMA, SWA, RNG") + else: + log0(" Lightweight checkpoint — model only") + del ckpt + + # DDP + base_model = model # keep reference before wrapping + if distributed: + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + # 
Optimizer setup + matrix_params = [] + scalar_params = [] + embed_params = [] + for name, p in base_model.named_parameters(): + if not p.requires_grad: + continue + if "tok_emb" in name: + embed_params.append(p) + elif p.ndim == 2 and min(p.shape) >= 2: + matrix_params.append(p) + else: + scalar_params.append(p) + + log0(f"Matrix params: {sum(p.numel() for p in matrix_params):,}") + log0(f"Scalar params: {sum(p.numel() for p in scalar_params):,}") + log0(f"Embed params: {sum(p.numel() for p in embed_params):,}") + + muon_opt = Muon( + matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + adam_opt = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr}, + {"params": embed_params, "lr": args.tied_embed_lr}], + betas=(args.beta1, args.beta2), + weight_decay=args.adam_wd, + fused=True, + ) + + # Deferred restore: optimizer states (must happen after optimizer creation) + if resume_state is not None: + muon_opt.load_state_dict(resume_state["muon_opt_state"]) + adam_opt.load_state_dict(resume_state["adam_opt_state"]) + log0(" Restored optimizer states (Muon + Adam)") + + # Data loader — coprime shard ordering (race-free: only rank 0 writes, all barrier, then all read) + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if not shard_order_file: + shard_order_path = f"/tmp/shard_order_{args.run_id}.txt" + if master_process: + shard_files = sorted(glob.glob(args.train_files)) + if shard_files: + ordered = generate_coprime_shard_order(shard_files, seed=args.seed) + with open(shard_order_path, "w") as f: + for sf in ordered: + f.write(str(sf) + "\n") + log0(f"Generated coprime shard order: stride across {len(shard_files)} shards") + if distributed: + dist.barrier() + if os.path.exists(shard_order_path): + os.environ["SHARD_ORDER_FILE"] = shard_order_path + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # LR schedule with 
warmdown (cosine) + def lr_schedule(step: int) -> float: + warmdown_start = args.iterations - args.warmdown_iters + if step < args.warmup_steps: + return step / max(1, args.warmup_steps) + elif step >= warmdown_start: + progress = (step - warmdown_start) / args.warmdown_iters + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return 1.0 + + # ─── EMA + SWA state ───────────────────────────────────────────────── + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Deferred restore: EMA, SWA, QAT, RNG, stream state + if resume_state is not None: + # EMA + saved_ema = resume_state.get("ema_state") + if saved_ema is not None: + ema_state = {k: v.to(device).float() for k, v in saved_ema.items()} + log0(" Restored EMA state") + # SWA + saved_swa = resume_state.get("swa_state") + if saved_swa is not None: + swa_state = {k: v.cpu() for k, v in saved_swa.items()} + swa_count = resume_state.get("swa_count", 0) + log0(f" Restored SWA state (count={swa_count})") + else: + swa_count = resume_state.get("swa_count", 0) + # QAT + if resume_state.get("qat_enabled", False): + CastedLinear._qat_enabled = True + log0(" Restored QAT enabled state") + # RNG states + saved_rng = resume_state.get("rng_states") + if saved_rng is not None: + torch.set_rng_state(saved_rng["torch_cpu"]) + torch.cuda.set_rng_state(saved_rng["torch_cuda"]) + np.random.set_state(saved_rng["numpy"]) + random.setstate(saved_rng["python"]) + log0(" Restored RNG states") + # Stream state (data loader fast-forward) + saved_stream = resume_state.get("stream_state") + if saved_stream is not None: + s_idx, s_pos = saved_stream + stream = train_loader.stream + # Advance to the saved shard + while stream.idx != s_idx: + stream._advance_file() + stream.pos = s_pos + log0(f" Restored stream state (shard={s_idx}, pos={s_pos})") + else: + # No stream state saved — fast-forward by consuming tokens + if start_step > 
0: + log0(f" Fast-forwarding data loader by {start_step} steps...") + for _ in range(start_step): + train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + log0(f" Data loader advanced to step {start_step}") + del resume_state + log0(" Full checkpoint restore complete") + + # ─── Training Loop ─────────────────────────────────────────────────── + # Clear stale chain marker from previous segment (if any) + stale_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(stale_marker): + os.remove(stale_marker) + + log0(f"\n{'='*80}") + log0(f"Starting training: {args.iterations} steps (from step {start_step})") + log0(f"{'='*80}\n") + + t0 = time.time() + running_loss = 0.0 + loss_count = 0 + stop_after_step = None + step = start_step # ensure step is defined even if loop doesn't execute + + for step in range(start_step + 1, args.iterations + 1): + # Check early stop + if stop_after_step is not None and step > stop_after_step: + log0(f"Stopping early at step {step} (wallclock limit)") + break + + lr_mul = lr_schedule(step) + + # Muon momentum warmup + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + current_muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in muon_opt.param_groups: + group["lr"] = args.matrix_lr * lr_mul + group["momentum"] = current_muon_momentum + for i, pg in enumerate(adam_opt.param_groups): + if i == 0: + pg["lr"] = args.scalar_lr * lr_mul + else: + pg["lr"] = args.tied_embed_lr * lr_mul + + # Late QAT: activate int6 STE only during warmdown (not warmup!) 
+ warmdown_start = args.iterations - args.warmdown_iters + if (args.late_qat_threshold > 0 and step >= warmdown_start + and lr_mul < args.late_qat_threshold and not CastedLinear._qat_enabled): + CastedLinear._qat_enabled = True + log0(f"Late QAT enabled at step {step} (lr_mul={lr_mul:.4f})") + + # Gradient accumulation + model.train() + total_loss = 0.0 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + micro_batch = x.shape[0] // grad_accum_steps + for micro_step in range(grad_accum_steps): + x_micro = x[micro_step * micro_batch:(micro_step + 1) * micro_batch] + y_micro = y[micro_step * micro_batch:(micro_step + 1) * micro_batch] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x_micro, y_micro) + loss = loss / grad_accum_steps + loss.backward() + total_loss += loss.item() + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm) + + muon_opt.step() + adam_opt.step() + muon_opt.zero_grad(set_to_none=True) + adam_opt.zero_grad(set_to_none=True) + + # EMA update (every step) + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + # SWA: collect checkpoints during late warmdown + if args.swa_enabled and lr_mul < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"SWA started at step {step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + running_loss += total_loss + loss_count += 1 + + # Logging + if step % args.train_log_every == 0 or step <= 10: + avg_loss = running_loss / max(loss_count, 1) + elapsed = time.time() - t0 + steps_per_sec = step / elapsed + log0(f"step {step:5d}/{args.iterations} | loss {avg_loss:.4f} | lr_mul {lr_mul:.4f} 
| " + f"mom {current_muon_momentum:.3f} | {steps_per_sec:.2f} steps/s | {elapsed:.0f}s") + running_loss = 0.0 + loss_count = 0 + + # Validation + checkpoint + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or step == args.iterations: + val_loss, val_bpb = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=args.xsa_eval, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"step {step:5d} | val_loss {val_loss:.4f} | val_bpb {val_bpb:.4f}") + + if master_process and args.save_every > 0 and (step % args.save_every == 0 or step == args.iterations): + ckpt_path = save_checkpoint( + model, step, val_bpb, args.ckpt_dir, config["arch_name"], args.seed, + ) + log0(f" Saved: {ckpt_path}") + + # Wallclock limit + if args.max_wallclock_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.max_wallclock_seconds and stop_after_step is None: + stop_after_step = step + log0(f"Wallclock limit reached ({elapsed:.0f}s), will stop after this step") + + # Auto-save for chained job support + if args.auto_save_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.auto_save_seconds: + log0(f"Auto-save triggered at step {step} ({elapsed:.0f}s elapsed)") + if master_process: + rng_states = { + "torch_cpu": torch.get_rng_state(), + "torch_cuda": torch.cuda.get_rng_state(), + "numpy": np.random.get_state(), + "python": random.getstate(), + } + stream = train_loader.stream + stream_state = (stream.idx, stream.pos) + ckpt_path = save_full_checkpoint( + model, step, 0.0, args.ckpt_dir, config["arch_name"], args.seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + CastedLinear._qat_enabled, + rng_states=rng_states, stream_state=stream_state, + ) + # Write chain resume marker + marker_path = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + with open(marker_path, "w") as f: + 
f.write(ckpt_path + "\n") + log0(f" Full checkpoint saved: {ckpt_path}") + log0(f" Chain resume marker: {marker_path}") + break # exit training loop cleanly + + # ─── Check if we exited due to auto-save vs normal completion ──────── + chain_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(chain_marker): + log0("\nExiting for chained job resume (skipping post-training)") + if distributed: + dist.destroy_process_group() + return + + # Check for total_iterations completion + effective_total = args.total_iterations if args.total_iterations > 0 else args.iterations + if master_process and step >= effective_total: + complete_marker = os.path.join(args.ckpt_dir, f"TRAINING_COMPLETE_seed{args.seed}") + with open(complete_marker, "w") as f: + f.write(f"step={step}\n") + + # ─── Post-Training: Apply EMA ──────────────────────────────────────── + elapsed_total = time.time() - t0 + log0(f"\nTraining complete in {elapsed_total:.0f}s") + log0(f"Peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + log0("\n=== Applying EMA weights ===") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + # Eval EMA weights + val_loss_ema, val_bpb_ema = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"EMA BPB (no XSA): {val_bpb_ema:.6f}") + + # Save raw EMA model + if master_process: + torch.save(base_model.state_dict(), os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.pt")) + log0("Saved raw EMA model") + + # ─── GPTQ Calibration (optional) ───────────────────────────────────── + gptq_enabled = bool(int(os.environ.get("GPTQ_ENABLED", "0"))) + hessians = None + if gptq_enabled: + 
log0("\n=== GPTQ: generating autoregressive calibration data ===") + calib_seqs = generate_autoregressive_calib( + base_model, device, num_seqs=64, seq_len=args.train_seq_len, + vocab_size=args.vocab_size, temperature=0.8, batch_size=8, seed=args.seed, + ) + log0(f"GPTQ: generated {len(calib_seqs)} sequences, collecting hessians...") + hessians = collect_hessians_from_tokens(base_model, calib_seqs, device) + log0(f"GPTQ: collected hessians for {len(hessians)} layers") + + # ─── Quantization + Artifact Creation ──────────────────────────────── + compressor = os.environ.get("COMPRESSOR", "zstd").lower() + log0(f"\n=== Quantizing to int6 + {'brotli-11' if compressor == 'brotli' else 'zstd-22'} ===") + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, hessians=hessians) + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if compressor == "brotli": + import brotli + quant_blob = brotli.compress(quant_raw, quality=11) + else: + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + + artifact_path = os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.int6.ptz") + if master_process: + with open(artifact_path, "wb") as f: + f.write(quant_blob) + artifact_bytes = len(quant_blob) + log0(f"Artifact: {artifact_bytes:,} bytes ({artifact_bytes / 1e6:.2f} MB)") + # Budget check uses the decimal cap (16,000,000 bytes), not 16 MiB. + if artifact_bytes > 16_000_000: + log0(f"WARNING: Artifact exceeds the 16,000,000-byte cap by {(artifact_bytes - 16_000_000) / 1024:.1f} KB") + + # ─── Roundtrip Validation ──────────────────────────────────────────── + log0("\n=== Roundtrip Validation (quantized model) ===") + if distributed: + dist.barrier() + + with open(artifact_path, "rb") as f: + quant_blob_disk = f.read() + if compressor == "brotli": + import brotli + decompressed = brotli.decompress(quant_blob_disk) + else: + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + quant_state = 
torch.load( + io.BytesIO(decompressed), + map_location="cpu", + ) + deq_state = dequantize_mixed(quant_state["w"], quant_state["m"], sd_cpu) + + # Build fresh eval model (no DDP wrapping needed) + eval_model = HybridGDN(config, args.vocab_size).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + for name, p in eval_model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + eval_model.load_state_dict(deq_state, strict=True) + + # Eval quantized model without XSA + val_loss_q, val_bpb_q = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"Quantized BPB (no XSA): {val_bpb_q:.6f}") + log0(f"Quantization degradation: {val_bpb_q - val_bpb_ema:+.6f}") + + # Eval quantized model WITH XSA (if model has SWA layers) + block_types = eval_model._block_types + if any(bt in ("swa", "swa_shared") for bt in block_types): + val_loss_qx, val_bpb_qx = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=True, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"Quantized BPB (XSA-all): {val_bpb_qx:.6f}") + + # ─── Final Summary ─────────────────────────────────────────────────── + log0(f"\n{'='*80}") + log0(f"FINAL RESULTS — {config['arch_name']} seed={args.seed}") + log0(f" Training: {args.iterations} steps, {elapsed_total:.0f}s") + log0(f" EMA BPB (fp32): {val_bpb_ema:.6f}") + log0(f" Quantized BPB: {val_bpb_q:.6f}") + if any(bt in ("swa", "swa_shared") for bt in block_types): + log0(f" Quantized BPB+XSA: {val_bpb_qx:.6f}") + if master_process: + log0(f" Artifact size: {artifact_bytes:,} bytes") + log0(f"{'='*80}") + log0(f"final_int6_roundtrip_exact 
val_loss:{val_loss_q:.8f} val_bpb:{val_bpb_q:.8f}") + + # ─── Legal Score-First TTT ──────────────────────────────────────────── + if args.ttt_enabled: + log0("\n=== Legal Score-First TTT (GDN) ===") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_ttt_gdn( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.ttt_batch_seqs, log0=log0, + ) + torch.cuda.synchronize() + ttt_elapsed = time.perf_counter() - t_ttt + ttt_delta = ttt_bpb - val_bpb_q + log0(f"TTT BPB: {ttt_bpb:.6f} (delta: {ttt_delta:+.6f})") + log0(f"TTT eval time: {ttt_elapsed:.1f}s") + log0(f"final_int6_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_gpt.py b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_gpt.py new file mode 100644 index 0000000000..c92492b069 --- /dev/null +++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_gpt.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""FLA / GatedDeltaNet entrypoint wrapper. + +The actual training logic lives in `train_gdn_7k.py`. `evaluate.py` expects +`torchrun train_gpt.py`, so this wrapper preserves the standard repo entrypoint +while keeping the scored path in the records folder self-contained. +""" + +import os +import sys +import traceback +from pathlib import Path + +# These defaults keep the wrapper aligned with the intended SP8192 scored path. 
+VOCAB_SIZE = int(os.environ.get("VOCAB_SIZE", 8192)) +DATA_PATH = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192") +TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") +ARCH_MODE = os.environ.get("ARCH_MODE", "K") +os.environ.setdefault("VOCAB_SIZE", str(VOCAB_SIZE)) +os.environ.setdefault("DATA_PATH", DATA_PATH) +os.environ.setdefault("TOKENIZER_PATH", TOKENIZER_PATH) +os.environ.setdefault("ARCH_MODE", ARCH_MODE) +os.environ.setdefault("MAX_WALLCLOCK_SECONDS", "600") +os.environ.setdefault("VAL_LOSS_EVERY", "0") +os.environ.setdefault("EVAL_COMPILE_ENABLED", "0") +if ARCH_MODE in ("D", "G", "M"): + os.environ.setdefault("XSA_EVAL", "1") + + +_VENDOR_DIR = Path(__file__).resolve().parent / ".fla_vendor" +_VENDOR_PKGS = [ + "triton==3.2.0", + "flash-linear-attention==0.4.2", + "fla-core==0.4.2", + "transformers==5.5.4", + "tokenizers==0.22.2", + "safetensors==0.7.0", +] +if ARCH_MODE in ("F", "G"): + _VENDOR_PKGS.extend( + [ + "mamba-ssm==2.3.1", + "causal-conv1d==1.6.1", + ] + ) + + +def _ensure_vendor_on_path() -> None: + p = str(_VENDOR_DIR) + if p not in sys.path: + sys.path.insert(0, p) + + +def _ensure_fla_vendor_available() -> None: + _ensure_vendor_on_path() + try: + if ARCH_MODE in ("F", "G"): + from fla.layers.mamba2 import Mamba2 # noqa: F401 + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined # noqa: F401 + from causal_conv1d import causal_conv1d_fn # noqa: F401 + else: + from fla.layers.gated_deltanet import GatedDeltaNet # noqa: F401 + print("wrapper: local vendored FLA imports already work", flush=True) + return + except Exception: + vendor_pkgs = ", ".join(_VENDOR_PKGS) + raise RuntimeError( + "wrapper: required FLA deps are missing from the local environment. " + f"Expected vendored packages under {_VENDOR_DIR}. " + f"Install them before evaluation (e.g. 
via launcher/requirements), packages: {vendor_pkgs}" + ) + + +def main(): + _ensure_fla_vendor_available() + print("wrapper: importing train_gdn_7k", flush=True) + try: + from train_gdn_7k import main as train_main + except Exception: + traceback.print_exc() + raise + print("wrapper: import ok, entering train_main", flush=True) + train_main() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed314.log b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed314.log new file mode 100644 index 0000000000..3554bbbd3a --- /dev/null +++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed314.log @@ -0,0 +1,223 @@ +W0419 04:20:42.997000 484835 torch/distributed/run.py:803] +W0419 04:20:42.997000 484835 torch/distributed/run.py:803] ***************************************** +W0419 04:20:42.997000 484835 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0419 04:20:42.997000 484835 torch/distributed/run.py:803] ***************************************** +=== GDN Hybrid 7k Full Training === +Arch: K_KVShare_Wider (ARCH_MODE=K) +Seed: 314, Steps: 7000, Warmdown: 2100 +World size: 8, Grad accum: 1 +EMA decay: 0.997, SWA: True (every 50) +Late QAT threshold: 0.15 +Eval compile enabled: False +Validation tokens: 40,540,160 +Model built in 0.2s +Parameters: {'embedding': 4861441, 'recurrent': 11269920, 'attention': 5440, 'mlp': 17761600, 'other': 11424, 'total': 33909825} +Total params: 33,909,825 +Matrix params: 29,400,192 +Scalar params: 53,185 +Embed params: 4,456,448 +Generated coprime shard order: stride across 80 shards + +================================================================================ +Starting training: 7000 steps (from step 0) +================================================================================ + +[rank6]:[W419 04:20:58.949907071 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W419 04:20:58.025477490 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +[rank0]:[W419 04:20:58.147229703 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W419 04:20:58.148824722 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W419 04:20:58.149562581 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W419 04:20:58.155516603 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. 
This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W419 04:20:58.166621466 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W419 04:20:58.168280596 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +step 1/7000 | loss 9.0036 | lr_mul 0.0500 | mom 0.850 | 0.37 steps/s | 3s +step 2/7000 | loss 8.7877 | lr_mul 0.1000 | mom 0.850 | 0.67 steps/s | 3s +step 3/7000 | loss 8.2154 | lr_mul 0.1500 | mom 0.851 | 0.93 steps/s | 3s +step 4/7000 | loss 7.6211 | lr_mul 0.2000 | mom 0.851 | 1.15 steps/s | 3s +step 5/7000 | loss 7.4127 | lr_mul 0.2500 | mom 0.851 | 1.34 steps/s | 4s +step 6/7000 | loss 7.2753 | lr_mul 0.3000 | mom 0.851 | 1.51 steps/s | 4s +step 7/7000 | loss 7.3712 | lr_mul 0.3500 | mom 0.851 | 1.66 steps/s | 4s +step 8/7000 | loss 7.2815 | lr_mul 0.4000 | mom 0.852 | 1.79 steps/s | 4s +step 9/7000 | loss 7.0685 | lr_mul 0.4500 | mom 0.852 | 1.90 steps/s | 5s +step 10/7000 | loss 6.7714 | lr_mul 0.5000 | mom 0.852 | 2.01 steps/s | 5s +step 100/7000 | loss 5.0945 | lr_mul 1.0000 | mom 0.870 | 3.62 steps/s | 28s +step 200/7000 | loss 4.1632 | lr_mul 1.0000 | mom 0.890 | 3.74 steps/s | 53s +step 300/7000 | loss 3.7829 | lr_mul 1.0000 | mom 0.910 | 3.78 steps/s | 79s +step 400/7000 | loss 3.6376 | lr_mul 1.0000 | mom 0.930 | 3.80 steps/s | 105s +step 500/7000 | loss 3.5756 | lr_mul 1.0000 | mom 0.950 | 3.84 steps/s | 130s +step 600/7000 | loss 3.4631 | lr_mul 1.0000 | mom 0.950 | 3.84 steps/s | 156s +step 700/7000 | loss 3.4292 | lr_mul 1.0000 | mom 0.950 | 3.85 steps/s | 182s +step 800/7000 | loss 3.3958 | lr_mul 1.0000 | mom 0.950 | 3.85 steps/s | 208s +step 900/7000 | loss 3.3539 | lr_mul 1.0000 | mom 0.950 | 3.85 steps/s | 234s +step 1000/7000 | loss 3.3273 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 259s +step 1100/7000 | loss 3.3099 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 285s +step 1200/7000 | loss 3.2847 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 311s +step 1300/7000 | loss 3.2886 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 336s +step 1400/7000 | loss 3.3030 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 362s +step 1500/7000 | loss 3.2575 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 387s +step 1600/7000 | loss 3.2436 | lr_mul 
1.0000 | mom 0.950 | 3.87 steps/s | 413s +step 1700/7000 | loss 3.2508 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 439s +step 1800/7000 | loss 3.2339 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 465s +step 1900/7000 | loss 3.2364 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 490s +step 2000/7000 | loss 3.2252 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 516s +step 2100/7000 | loss 3.2294 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 542s +step 2200/7000 | loss 3.2011 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 568s +step 2300/7000 | loss 3.2101 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 593s +Wallclock limit reached (600s), will stop after this step +Stopping early at step 2328 (wallclock limit) + +Training complete in 600s +Peak memory: 41126 MiB + +=== Applying EMA weights === +EMA BPB (no XSA): 1.000334 +Saved raw EMA model + +=== Quantizing to int6 + brotli-11 === +Artifact: 15,527,172 bytes (14.81 MB) + +=== Roundtrip Validation (quantized model) === +Quantized BPB (no XSA): 1.019032 +Quantization degradation: +0.018697 + +================================================================================ +FINAL RESULTS — K_KVShare_Wider seed=314 + Training: 7000 steps, 600s + EMA BPB (fp32): 1.000334 + Quantized BPB: 1.019032 + Artifact size: 15,527,172 bytes +================================================================================ +final_int6_roundtrip_exact val_loss:3.09828560 val_bpb:1.01903158 + +=== Legal Score-First TTT (GDN) === +ttt_gdn:start chunks=1238 chunk_tokens=32768 total_windows=633440 stride=64 ttt_lr=0.005 ttt_epochs=3 freeze_blocks=2 +ttt_gdn:params unfrozen=28100257 frozen=5809568 + ttt_chunk [1/1238] bpb=1.029315 time=0.3s + ttt_chunk [11/1238] bpb=1.002259 time=2.5s + ttt_chunk [21/1238] bpb=1.034535 time=4.7s + ttt_chunk [31/1238] bpb=1.033169 time=6.9s + ttt_chunk [41/1238] bpb=1.028416 time=9.1s + ttt_chunk [51/1238] bpb=1.024284 time=11.3s + ttt_chunk [61/1238] bpb=1.018682 time=13.5s + ttt_chunk [71/1238] bpb=1.023913
time=15.7s + ttt_chunk [81/1238] bpb=1.019526 time=17.9s + ttt_chunk [91/1238] bpb=1.016725 time=20.1s + ttt_chunk [101/1238] bpb=1.015657 time=22.4s + ttt_chunk [111/1238] bpb=1.014952 time=24.6s + ttt_chunk [121/1238] bpb=1.017140 time=26.8s + ttt_chunk [131/1238] bpb=1.020095 time=29.0s + ttt_chunk [141/1238] bpb=1.020178 time=31.2s + ttt_chunk [151/1238] bpb=1.020117 time=33.4s + ttt_chunk [161/1238] bpb=1.020523 time=35.6s + ttt_chunk [171/1238] bpb=1.020449 time=37.8s + ttt_chunk [181/1238] bpb=1.019266 time=40.0s + ttt_chunk [191/1238] bpb=1.018462 time=42.2s + ttt_chunk [201/1238] bpb=1.016276 time=44.4s + ttt_chunk [211/1238] bpb=1.019626 time=46.6s + ttt_chunk [221/1238] bpb=1.019026 time=48.8s + ttt_chunk [231/1238] bpb=1.020403 time=51.0s + ttt_chunk [241/1238] bpb=1.019741 time=53.2s + ttt_chunk [251/1238] bpb=1.019765 time=55.4s + ttt_chunk [261/1238] bpb=1.020071 time=57.6s + ttt_chunk [271/1238] bpb=1.020093 time=59.8s + ttt_chunk [281/1238] bpb=1.019162 time=62.0s + ttt_chunk [291/1238] bpb=1.019753 time=64.2s + ttt_chunk [301/1238] bpb=1.019607 time=66.4s + ttt_chunk [311/1238] bpb=1.018194 time=68.6s + ttt_chunk [321/1238] bpb=1.018094 time=70.8s + ttt_chunk [331/1238] bpb=1.018139 time=73.0s + ttt_chunk [341/1238] bpb=1.017400 time=75.2s + ttt_chunk [351/1238] bpb=1.018072 time=77.4s + ttt_chunk [361/1238] bpb=1.016916 time=79.6s + ttt_chunk [371/1238] bpb=1.015508 time=81.8s + ttt_chunk [381/1238] bpb=1.015386 time=84.0s + ttt_chunk [391/1238] bpb=1.014838 time=86.2s + ttt_chunk [401/1238] bpb=1.014526 time=88.5s + ttt_chunk [411/1238] bpb=1.014840 time=90.7s + ttt_chunk [421/1238] bpb=1.014268 time=92.9s + ttt_chunk [431/1238] bpb=1.014207 time=95.1s + ttt_chunk [441/1238] bpb=1.014276 time=97.3s + ttt_chunk [451/1238] bpb=1.015273 time=99.5s + ttt_chunk [461/1238] bpb=1.013816 time=101.7s + ttt_chunk [471/1238] bpb=1.013760 time=103.9s + ttt_chunk [481/1238] bpb=1.013847 time=106.1s + ttt_chunk [491/1238] bpb=1.014133 time=108.3s + ttt_chunk 
[501/1238] bpb=1.013712 time=110.5s + ttt_chunk [511/1238] bpb=1.013506 time=112.7s + ttt_chunk [521/1238] bpb=1.013267 time=114.9s + ttt_chunk [531/1238] bpb=1.013250 time=117.1s + ttt_chunk [541/1238] bpb=1.013350 time=119.3s + ttt_chunk [551/1238] bpb=1.013087 time=121.5s + ttt_chunk [561/1238] bpb=1.012728 time=123.7s + ttt_chunk [571/1238] bpb=1.012193 time=126.0s + ttt_chunk [581/1238] bpb=1.012353 time=128.2s + ttt_chunk [591/1238] bpb=1.012529 time=130.4s + ttt_chunk [601/1238] bpb=1.012589 time=132.6s + ttt_chunk [611/1238] bpb=1.013041 time=134.8s + ttt_chunk [621/1238] bpb=1.013708 time=137.0s + ttt_chunk [631/1238] bpb=1.013601 time=139.2s + ttt_chunk [641/1238] bpb=1.013751 time=141.4s + ttt_chunk [651/1238] bpb=1.013958 time=143.6s + ttt_chunk [661/1238] bpb=1.013259 time=145.8s + ttt_chunk [671/1238] bpb=1.012908 time=148.0s + ttt_chunk [681/1238] bpb=1.013953 time=150.2s + ttt_chunk [691/1238] bpb=1.013801 time=152.5s + ttt_chunk [701/1238] bpb=1.013467 time=154.7s + ttt_chunk [711/1238] bpb=1.013902 time=156.9s + ttt_chunk [721/1238] bpb=1.014087 time=159.1s + ttt_chunk [731/1238] bpb=1.013518 time=161.3s + ttt_chunk [741/1238] bpb=1.013363 time=163.5s + ttt_chunk [751/1238] bpb=1.012592 time=165.7s + ttt_chunk [761/1238] bpb=1.011896 time=167.9s + ttt_chunk [771/1238] bpb=1.011043 time=170.1s + ttt_chunk [781/1238] bpb=1.010933 time=172.3s + ttt_chunk [791/1238] bpb=1.011193 time=174.5s + ttt_chunk [801/1238] bpb=1.011146 time=176.7s + ttt_chunk [811/1238] bpb=1.010628 time=179.0s + ttt_chunk [821/1238] bpb=1.009853 time=181.2s + ttt_chunk [831/1238] bpb=1.009671 time=183.4s + ttt_chunk [841/1238] bpb=1.009346 time=185.6s + ttt_chunk [851/1238] bpb=1.009021 time=187.8s + ttt_chunk [861/1238] bpb=1.008355 time=190.0s + ttt_chunk [871/1238] bpb=1.008112 time=192.2s + ttt_chunk [881/1238] bpb=1.007749 time=194.4s + ttt_chunk [891/1238] bpb=1.007270 time=196.6s + ttt_chunk [901/1238] bpb=1.006838 time=198.8s + ttt_chunk [911/1238] bpb=1.006680 
time=201.0s + ttt_chunk [921/1238] bpb=1.007008 time=203.2s + ttt_chunk [931/1238] bpb=1.007692 time=205.4s + ttt_chunk [941/1238] bpb=1.008078 time=207.7s + ttt_chunk [951/1238] bpb=1.008001 time=209.9s + ttt_chunk [961/1238] bpb=1.008615 time=212.1s + ttt_chunk [971/1238] bpb=1.008646 time=214.3s + ttt_chunk [981/1238] bpb=1.009018 time=216.5s + ttt_chunk [991/1238] bpb=1.008870 time=218.7s + ttt_chunk [1001/1238] bpb=1.009122 time=220.9s + ttt_chunk [1011/1238] bpb=1.009435 time=223.1s + ttt_chunk [1021/1238] bpb=1.009983 time=225.3s + ttt_chunk [1031/1238] bpb=1.010465 time=227.5s + ttt_chunk [1041/1238] bpb=1.010655 time=229.7s + ttt_chunk [1051/1238] bpb=1.010542 time=232.0s + ttt_chunk [1061/1238] bpb=1.010603 time=234.2s + ttt_chunk [1071/1238] bpb=1.010677 time=236.4s + ttt_chunk [1081/1238] bpb=1.010587 time=238.6s + ttt_chunk [1091/1238] bpb=1.010692 time=240.8s + ttt_chunk [1101/1238] bpb=1.011093 time=243.0s + ttt_chunk [1111/1238] bpb=1.011312 time=245.2s + ttt_chunk [1121/1238] bpb=1.011483 time=247.4s + ttt_chunk [1131/1238] bpb=1.011132 time=249.6s + ttt_chunk [1141/1238] bpb=1.010806 time=251.8s + ttt_chunk [1151/1238] bpb=1.010858 time=254.0s + ttt_chunk [1161/1238] bpb=1.010998 time=256.3s + ttt_chunk [1171/1238] bpb=1.010790 time=258.5s + ttt_chunk [1181/1238] bpb=1.010428 time=260.7s + ttt_chunk [1191/1238] bpb=1.010565 time=262.9s + ttt_chunk [1201/1238] bpb=1.010764 time=265.1s + ttt_chunk [1211/1238] bpb=1.010565 time=267.3s + ttt_chunk [1221/1238] bpb=1.010138 time=269.5s + ttt_chunk [1231/1238] bpb=1.009858 time=271.7s + ttt_chunk [1238/1238] bpb=1.009840 time=273.1s +ttt_gdn:done val_loss=3.070153 val_bpb=1.009779 elapsed=273.1s +TTT BPB: 1.009779 (delta: -0.009253) +TTT eval time: 273.4s +final_int6_ttt_exact val_loss:3.07015325 val_bpb:1.00977880 diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed42.log 
b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed42.log new file mode 100644 index 0000000000..996a01ae55 --- /dev/null +++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed42.log @@ -0,0 +1,223 @@ +W0419 04:02:48.798000 482222 torch/distributed/run.py:803] +W0419 04:02:48.798000 482222 torch/distributed/run.py:803] ***************************************** +W0419 04:02:48.798000 482222 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0419 04:02:48.798000 482222 torch/distributed/run.py:803] ***************************************** +=== GDN Hybrid 7k Full Training === +Arch: K_KVShare_Wider (ARCH_MODE=K) +Seed: 42, Steps: 7000, Warmdown: 2100 +World size: 8, Grad accum: 1 +EMA decay: 0.997, SWA: True (every 50) +Late QAT threshold: 0.15 +Eval compile enabled: False +Validation tokens: 40,540,160 +Model built in 0.2s +Parameters: {'embedding': 4861441, 'recurrent': 11269920, 'attention': 5440, 'mlp': 17761600, 'other': 11424, 'total': 33909825} +Total params: 33,909,825 +Matrix params: 29,400,192 +Scalar params: 53,185 +Embed params: 4,456,448 +Generated coprime shard order: stride across 80 shards + +================================================================================ +Starting training: 7000 steps (from step 0) +================================================================================ + +[rank3]:[W419 04:03:03.370806367 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. 
If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W419 04:03:03.396195014 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W419 04:03:03.449282756 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W419 04:03:03.500485451 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +[rank6]:[W419 04:03:03.537735764 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank0]:[W419 04:03:03.540326742 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W419 04:03:03.552440030 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W419 04:03:03.566166981 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. 
This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +step 1/7000 | loss 9.0039 | lr_mul 0.0500 | mom 0.850 | 0.37 steps/s | 3s +step 2/7000 | loss 8.7907 | lr_mul 0.1000 | mom 0.850 | 0.68 steps/s | 3s +step 3/7000 | loss 8.0857 | lr_mul 0.1500 | mom 0.851 | 0.94 steps/s | 3s +step 4/7000 | loss 7.6596 | lr_mul 0.2000 | mom 0.851 | 1.17 steps/s | 3s +step 5/7000 | loss 7.3429 | lr_mul 0.2500 | mom 0.851 | 1.36 steps/s | 4s +step 6/7000 | loss 7.3719 | lr_mul 0.3000 | mom 0.851 | 1.53 steps/s | 4s +step 7/7000 | loss 7.3596 | lr_mul 0.3500 | mom 0.851 | 1.67 steps/s | 4s +step 8/7000 | loss 7.1506 | lr_mul 0.4000 | mom 0.852 | 1.81 steps/s | 4s +step 9/7000 | loss 7.0397 | lr_mul 0.4500 | mom 0.852 | 1.92 steps/s | 5s +step 10/7000 | loss 6.8402 | lr_mul 0.5000 | mom 0.852 | 2.03 steps/s | 5s +step 100/7000 | loss 5.0960 | lr_mul 1.0000 | mom 0.870 | 3.64 steps/s | 27s +step 200/7000 | loss 4.1433 | lr_mul 1.0000 | mom 0.890 | 3.75 steps/s | 53s +step 300/7000 | loss 3.7673 | lr_mul 1.0000 | mom 0.910 | 3.80 steps/s | 79s +step 400/7000 | loss 3.6307 | lr_mul 1.0000 | mom 0.930 | 3.82 steps/s | 105s +step 500/7000 | loss 3.5134 | lr_mul 1.0000 | mom 0.950 | 3.85 steps/s | 130s +step 600/7000 | loss 3.4756 | lr_mul 1.0000 | mom 0.950 | 3.85 steps/s | 156s +step 700/7000 | loss 3.4136 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 181s +step 800/7000 | loss 3.3883 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 207s +step 900/7000 | loss 3.3626 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 233s +step 1000/7000 | loss 3.3507 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 258s +step 1100/7000 | loss 3.3221 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 284s 
+step 1200/7000 | loss 3.3204 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 310s +step 1300/7000 | loss 3.3064 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 335s +step 1400/7000 | loss 3.2742 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 361s +step 1500/7000 | loss 3.2527 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 386s +step 1600/7000 | loss 3.2461 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 412s +step 1700/7000 | loss 3.2367 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 438s +step 1800/7000 | loss 3.2245 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 463s +step 1900/7000 | loss 3.2353 | lr_mul 1.0000 | mom 0.950 | 3.89 steps/s | 489s +step 2000/7000 | loss 3.2432 | lr_mul 1.0000 | mom 0.950 | 3.89 steps/s | 514s +step 2100/7000 | loss 3.2076 | lr_mul 1.0000 | mom 0.950 | 3.89 steps/s | 540s +step 2200/7000 | loss 3.2182 | lr_mul 1.0000 | mom 0.950 | 3.89 steps/s | 566s +step 2300/7000 | loss 3.2134 | lr_mul 1.0000 | mom 0.950 | 3.89 steps/s | 592s +Wallclock limit reached (600s), will stop after this step +Stopping early at step 2335 (wallclock limit) + +Training complete in 600s +Peak memory: 41126 MiB + +=== Applying EMA weights === +EMA BPB (no XSA): 1.002574 +Saved raw EMA model + +=== Quantizing to int6 + brotli-11 === +Artifact: 15,543,829 bytes (14.82 MB) + +=== Roundtrip Validation (quantized model) === +Quantized BPB (no XSA): 1.021889 +Quantization degradation: +0.019315 + +================================================================================ +FINAL RESULTS — K_KVShare_Wider seed=42 + Training: 7000 steps, 600s + EMA BPB (fp32): 1.002574 + Quantized BPB: 1.021889 + Artifact size: 15,543,829 bytes +================================================================================ +final_int6_roundtrip_exact val_loss:3.10697387 val_bpb:1.02188917 + +=== Legal Score-First TTT (GDN) === +ttt_gdn:start chunks=1238 chunk_tokens=32768 total_windows=633440 stride=64 ttt_lr=0.005 ttt_epochs=3 freeze_blocks=2 +ttt_gdn:params unfrozen=28100257 frozen=5809568 +
ttt_chunk [1/1238] bpb=1.035357 time=0.3s + ttt_chunk [11/1238] bpb=1.004607 time=2.6s + ttt_chunk [21/1238] bpb=1.037103 time=4.8s + ttt_chunk [31/1238] bpb=1.035531 time=7.0s + ttt_chunk [41/1238] bpb=1.030595 time=9.2s + ttt_chunk [51/1238] bpb=1.026634 time=11.4s + ttt_chunk [61/1238] bpb=1.020815 time=13.6s + ttt_chunk [71/1238] bpb=1.025739 time=15.9s + ttt_chunk [81/1238] bpb=1.021453 time=18.1s + ttt_chunk [91/1238] bpb=1.018654 time=20.3s + ttt_chunk [101/1238] bpb=1.017750 time=22.5s + ttt_chunk [111/1238] bpb=1.016962 time=24.7s + ttt_chunk [121/1238] bpb=1.019223 time=27.0s + ttt_chunk [131/1238] bpb=1.022301 time=29.2s + ttt_chunk [141/1238] bpb=1.022330 time=31.4s + ttt_chunk [151/1238] bpb=1.022191 time=33.6s + ttt_chunk [161/1238] bpb=1.022552 time=35.8s + ttt_chunk [171/1238] bpb=1.022569 time=38.1s + ttt_chunk [181/1238] bpb=1.021325 time=40.3s + ttt_chunk [191/1238] bpb=1.020621 time=42.5s + ttt_chunk [201/1238] bpb=1.018391 time=44.7s + ttt_chunk [211/1238] bpb=1.021845 time=47.0s + ttt_chunk [221/1238] bpb=1.021232 time=49.2s + ttt_chunk [231/1238] bpb=1.022640 time=51.4s + ttt_chunk [241/1238] bpb=1.021950 time=53.6s + ttt_chunk [251/1238] bpb=1.022021 time=55.8s + ttt_chunk [261/1238] bpb=1.022306 time=58.1s + ttt_chunk [271/1238] bpb=1.022379 time=60.3s + ttt_chunk [281/1238] bpb=1.021519 time=62.5s + ttt_chunk [291/1238] bpb=1.022126 time=64.7s + ttt_chunk [301/1238] bpb=1.021909 time=66.9s + ttt_chunk [311/1238] bpb=1.020493 time=69.2s + ttt_chunk [321/1238] bpb=1.020346 time=71.4s + ttt_chunk [331/1238] bpb=1.020397 time=73.6s + ttt_chunk [341/1238] bpb=1.019656 time=75.8s + ttt_chunk [351/1238] bpb=1.020352 time=78.1s + ttt_chunk [361/1238] bpb=1.019289 time=80.3s + ttt_chunk [371/1238] bpb=1.017811 time=82.5s + ttt_chunk [381/1238] bpb=1.017716 time=84.7s + ttt_chunk [391/1238] bpb=1.017157 time=86.9s + ttt_chunk [401/1238] bpb=1.016836 time=89.2s + ttt_chunk [411/1238] bpb=1.017198 time=91.4s + ttt_chunk [421/1238] bpb=1.016630 
time=93.6s
+ ttt_chunk [431/1238] bpb=1.016545 time=95.8s
+ ttt_chunk [441/1238] bpb=1.016589 time=98.0s
+ ttt_chunk [451/1238] bpb=1.017589 time=100.3s
+ ttt_chunk [461/1238] bpb=1.016171 time=102.5s
+ ttt_chunk [471/1238] bpb=1.016117 time=104.7s
+ ttt_chunk [481/1238] bpb=1.016173 time=106.9s
+ ttt_chunk [491/1238] bpb=1.016429 time=109.1s
+ ttt_chunk [501/1238] bpb=1.015989 time=111.3s
+ ttt_chunk [511/1238] bpb=1.015830 time=113.5s
+ ttt_chunk [521/1238] bpb=1.015579 time=115.8s
+ ttt_chunk [531/1238] bpb=1.015540 time=118.0s
+ ttt_chunk [541/1238] bpb=1.015652 time=120.2s
+ ttt_chunk [551/1238] bpb=1.015367 time=122.4s
+ ttt_chunk [561/1238] bpb=1.015010 time=124.6s
+ ttt_chunk [571/1238] bpb=1.014464 time=126.8s
+ ttt_chunk [581/1238] bpb=1.014617 time=129.1s
+ ttt_chunk [591/1238] bpb=1.014794 time=131.3s
+ ttt_chunk [601/1238] bpb=1.014875 time=133.5s
+ ttt_chunk [611/1238] bpb=1.015326 time=135.7s
+ ttt_chunk [621/1238] bpb=1.016016 time=137.9s
+ ttt_chunk [631/1238] bpb=1.015916 time=140.2s
+ ttt_chunk [641/1238] bpb=1.016070 time=142.4s
+ ttt_chunk [651/1238] bpb=1.016297 time=144.6s
+ ttt_chunk [661/1238] bpb=1.015589 time=146.8s
+ ttt_chunk [671/1238] bpb=1.015184 time=149.0s
+ ttt_chunk [681/1238] bpb=1.016265 time=151.3s
+ ttt_chunk [691/1238] bpb=1.016128 time=153.5s
+ ttt_chunk [701/1238] bpb=1.015792 time=155.7s
+ ttt_chunk [711/1238] bpb=1.016251 time=157.9s
+ ttt_chunk [721/1238] bpb=1.016462 time=160.2s
+ ttt_chunk [731/1238] bpb=1.015894 time=162.4s
+ ttt_chunk [741/1238] bpb=1.015762 time=164.6s
+ ttt_chunk [751/1238] bpb=1.014995 time=166.8s
+ ttt_chunk [761/1238] bpb=1.014298 time=169.0s
+ ttt_chunk [771/1238] bpb=1.013432 time=171.3s
+ ttt_chunk [781/1238] bpb=1.013322 time=173.5s
+ ttt_chunk [791/1238] bpb=1.013584 time=175.7s
+ ttt_chunk [801/1238] bpb=1.013576 time=177.9s
+ ttt_chunk [811/1238] bpb=1.013028 time=180.2s
+ ttt_chunk [821/1238] bpb=1.012222 time=182.4s
+ ttt_chunk [831/1238] bpb=1.012046 time=184.6s
+ ttt_chunk [841/1238] bpb=1.011696 time=186.8s
+ ttt_chunk [851/1238] bpb=1.011365 time=189.0s
+ ttt_chunk [861/1238] bpb=1.010738 time=191.3s
+ ttt_chunk [871/1238] bpb=1.010506 time=193.5s
+ ttt_chunk [881/1238] bpb=1.010148 time=195.7s
+ ttt_chunk [891/1238] bpb=1.009664 time=197.9s
+ ttt_chunk [901/1238] bpb=1.009231 time=200.1s
+ ttt_chunk [911/1238] bpb=1.009109 time=202.4s
+ ttt_chunk [921/1238] bpb=1.009399 time=204.6s
+ ttt_chunk [931/1238] bpb=1.010058 time=206.8s
+ ttt_chunk [941/1238] bpb=1.010421 time=209.0s
+ ttt_chunk [951/1238] bpb=1.010360 time=211.2s
+ ttt_chunk [961/1238] bpb=1.010969 time=213.5s
+ ttt_chunk [971/1238] bpb=1.010987 time=215.7s
+ ttt_chunk [981/1238] bpb=1.011331 time=217.9s
+ ttt_chunk [991/1238] bpb=1.011160 time=220.1s
+ ttt_chunk [1001/1238] bpb=1.011365 time=222.3s
+ ttt_chunk [1011/1238] bpb=1.011664 time=224.5s
+ ttt_chunk [1021/1238] bpb=1.012226 time=226.8s
+ ttt_chunk [1031/1238] bpb=1.012704 time=229.0s
+ ttt_chunk [1041/1238] bpb=1.012897 time=231.2s
+ ttt_chunk [1051/1238] bpb=1.012781 time=233.4s
+ ttt_chunk [1061/1238] bpb=1.012851 time=235.6s
+ ttt_chunk [1071/1238] bpb=1.012924 time=237.8s
+ ttt_chunk [1081/1238] bpb=1.012822 time=240.1s
+ ttt_chunk [1091/1238] bpb=1.012917 time=242.3s
+ ttt_chunk [1101/1238] bpb=1.013329 time=244.5s
+ ttt_chunk [1111/1238] bpb=1.013526 time=246.7s
+ ttt_chunk [1121/1238] bpb=1.013686 time=248.9s
+ ttt_chunk [1131/1238] bpb=1.013338 time=251.1s
+ ttt_chunk [1141/1238] bpb=1.013013 time=253.4s
+ ttt_chunk [1151/1238] bpb=1.013063 time=255.6s
+ ttt_chunk [1161/1238] bpb=1.013196 time=257.8s
+ ttt_chunk [1171/1238] bpb=1.012984 time=260.0s
+ ttt_chunk [1181/1238] bpb=1.012620 time=262.2s
+ ttt_chunk [1191/1238] bpb=1.012770 time=264.5s
+ ttt_chunk [1201/1238] bpb=1.012999 time=266.7s
+ ttt_chunk [1211/1238] bpb=1.012823 time=268.9s
+ ttt_chunk [1221/1238] bpb=1.012404 time=271.1s
+ ttt_chunk [1231/1238] bpb=1.012133 time=273.3s
+ ttt_chunk [1238/1238] bpb=1.012123 time=274.7s
+ttt_gdn:done val_loss=3.077058 val_bpb=1.012050 elapsed=274.7s
+TTT BPB: 1.012050 (delta: -0.009839)
+TTT eval time: 275.0s
+final_int6_ttt_exact val_loss:3.07705825 val_bpb:1.01204987
diff --git a/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed999.log b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed999.log
new file mode 100644
index 0000000000..948829610e
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-19_GatedDeltaNet_MacroPhase_Brotli_LegalTTT/train_seed999.log
@@ -0,0 +1,223 @@
+W0419 04:38:33.725000 486658 torch/distributed/run.py:803]
+W0419 04:38:33.725000 486658 torch/distributed/run.py:803] *****************************************
+W0419 04:38:33.725000 486658 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0419 04:38:33.725000 486658 torch/distributed/run.py:803] *****************************************
+=== GDN Hybrid 7k Full Training ===
+Arch: K_KVShare_Wider (ARCH_MODE=K)
+Seed: 999, Steps: 7000, Warmdown: 2100
+World size: 8, Grad accum: 1
+EMA decay: 0.997, SWA: True (every 50)
+Late QAT threshold: 0.15
+Eval compile enabled: False
+Validation tokens: 40,540,160
+Model built in 0.2s
+Parameters: {'embedding': 4861441, 'recurrent': 11269920, 'attention': 5440, 'mlp': 17761600, 'other': 11424, 'total': 33909825}
+Total params: 33,909,825
+Matrix params: 29,400,192
+Scalar params: 53,185
+Embed params: 4,456,448
+Generated coprime shard order: stride across 80 shards
+
+================================================================================
+Starting training: 7000 steps (from step 0)
+================================================================================
+
+[rank1]:[W419 04:38:48.400500660 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
+[rank2]:[W419 04:38:48.423444414 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
+[rank4]:[W419 04:38:48.471432166 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
+[rank0]:[W419 04:38:48.547547724 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
+[rank3]:[W419 04:38:48.564511049 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
+[rank6]:[W419 04:38:48.567769685 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
+[rank7]:[W419 04:38:48.579583479 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
+[rank5]:[W419 04:38:48.589388057 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
+step 1/7000 | loss 9.0038 | lr_mul 0.0500 | mom 0.850 | 0.37 steps/s | 3s
+step 2/7000 | loss 8.7771 | lr_mul 0.1000 | mom 0.850 | 0.68 steps/s | 3s
+step 3/7000 | loss 8.1668 | lr_mul 0.1500 | mom 0.851 | 0.94 steps/s | 3s
+step 4/7000 | loss 7.6362 | lr_mul 0.2000 | mom 0.851 | 1.16 steps/s | 3s
+step 5/7000 | loss 7.3105 | lr_mul 0.2500 | mom 0.851 | 1.35 steps/s | 4s
+step 6/7000 | loss 7.3627 | lr_mul 0.3000 | mom 0.851 | 1.52 steps/s | 4s
+step 7/7000 | loss 7.4215 | lr_mul 0.3500 | mom 0.851 | 1.67 steps/s | 4s
+step 8/7000 | loss 7.2637 | lr_mul 0.4000 | mom 0.852 | 1.80 steps/s | 4s
+step 9/7000 | loss 7.0691 | lr_mul 0.4500 | mom 0.852 | 1.91 steps/s | 5s
+step 10/7000 | loss 6.8828 | lr_mul 0.5000 | mom 0.852 | 2.02 steps/s | 5s
+step 100/7000 | loss 5.1178 | lr_mul 1.0000 | mom 0.870 | 3.64 steps/s | 27s
+step 200/7000 | loss 4.1826 | lr_mul 1.0000 | mom 0.890 | 3.75 steps/s | 53s
+step 300/7000 | loss 3.7953 | lr_mul 1.0000 | mom 0.910 | 3.79 steps/s | 79s
+step 400/7000 | loss 3.6147 | lr_mul 1.0000 | mom 0.930 | 3.81 steps/s | 105s
+step 500/7000 | loss 3.5267 | lr_mul 1.0000 | mom 0.950 | 3.85 steps/s | 130s
+step 600/7000 | loss 3.4603 | lr_mul 1.0000 | mom 0.950 | 3.85 steps/s | 156s
+step 700/7000 | loss 3.4323 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 182s
+step 800/7000 | loss 3.3694 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 207s
+step 900/7000 | loss 3.3471 | lr_mul 1.0000 | mom 0.950 | 3.86 steps/s | 233s
+step 1000/7000 | loss 3.3194 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 258s
+step 1100/7000 | loss 3.3070 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 284s
+step 1200/7000 | loss 3.2888 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 310s
+step 1300/7000 | loss 3.2922 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 336s
+step 1400/7000 | loss 3.2816 | lr_mul 1.0000 | mom 0.950 | 3.87 steps/s | 361s
+step 1500/7000 | loss 3.2357 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 386s
+step 1600/7000 | loss 3.2383 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 412s
+step 1700/7000 | loss 3.2317 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 438s
+step 1800/7000 | loss 3.2304 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 464s
+step 1900/7000 | loss 3.2278 | lr_mul 1.0000 | mom 0.950 | 3.89 steps/s | 489s
+step 2000/7000 | loss 3.2254 | lr_mul 1.0000 | mom 0.950 | 3.89 steps/s | 515s
+step 2100/7000 | loss 3.2203 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 541s
+step 2200/7000 | loss 3.1922 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 566s
+step 2300/7000 | loss 3.1991 | lr_mul 1.0000 | mom 0.950 | 3.88 steps/s | 592s
+Wallclock limit reached (600s), will stop after this step
+Stopping early at step 2333 (wallclock limit)
+
+Training complete in 600s
+Peak memory: 41126 MiB
+
+=== Applying EMA weights ===
+EMA BPB (no XSA): 1.001464
+Saved raw EMA model
+
+=== Quantizing to int6 + zstd-22 ===
+Artifact: 15,524,066 bytes (14.80 MB)
+
+=== Roundtrip Validation (quantized model) ===
+Quantized BPB (no XSA): 1.019863
+Quantization degradation: +0.018399
+
+================================================================================
+FINAL RESULTS — K_KVShare_Wider seed=999
+ Training: 7000 steps, 600s
+ EMA BPB (fp32): 1.001464
+ Quantized BPB: 1.019863
+ Artifact size: 15,524,066 bytes
+================================================================================
+final_int6_roundtrip_exact val_loss:3.10081279 val_bpb:1.01986278
+
+=== Legal Score-First TTT (GDN) ===
+ttt_gdn:start chunks=1238 chunk_tokens=32768 total_windows=633440 stride=64 ttt_lr=0.005 ttt_epochs=3 freeze_blocks=2
+ttt_gdn:params unfrozen=28100257 frozen=5809568
+ ttt_chunk [1/1238] bpb=1.036362 time=0.3s
+ ttt_chunk [11/1238] bpb=1.005135 time=2.5s
+ ttt_chunk [21/1238] bpb=1.036218 time=4.7s
+ ttt_chunk [31/1238] bpb=1.033902 time=7.0s
+ ttt_chunk [41/1238] bpb=1.029392 time=9.2s
+ ttt_chunk [51/1238] bpb=1.024958 time=11.4s
+ ttt_chunk [61/1238] bpb=1.019176 time=13.6s
+ ttt_chunk [71/1238] bpb=1.024672 time=15.9s
+ ttt_chunk [81/1238] bpb=1.020096 time=18.1s
+ ttt_chunk [91/1238] bpb=1.017387 time=20.3s
+ ttt_chunk [101/1238] bpb=1.016593 time=22.5s
+ ttt_chunk [111/1238] bpb=1.015785 time=24.7s
+ ttt_chunk [121/1238] bpb=1.018006 time=27.0s
+ ttt_chunk [131/1238] bpb=1.020981 time=29.2s
+ ttt_chunk [141/1238] bpb=1.021111 time=31.4s
+ ttt_chunk [151/1238] bpb=1.020982 time=33.6s
+ ttt_chunk [161/1238] bpb=1.021390 time=35.8s
+ ttt_chunk [171/1238] bpb=1.021326 time=38.1s
+ ttt_chunk [181/1238] bpb=1.020266 time=40.3s
+ ttt_chunk [191/1238] bpb=1.019450 time=42.5s
+ ttt_chunk [201/1238] bpb=1.017236 time=44.7s
+ ttt_chunk [211/1238] bpb=1.020594 time=47.0s
+ ttt_chunk [221/1238] bpb=1.020046 time=49.2s
+ ttt_chunk [231/1238] bpb=1.021409 time=51.4s
+ ttt_chunk [241/1238] bpb=1.020740 time=53.6s
+ ttt_chunk [251/1238] bpb=1.020676 time=55.8s
+ ttt_chunk [261/1238] bpb=1.020974 time=58.1s
+ ttt_chunk [271/1238] bpb=1.021166 time=60.3s
+ ttt_chunk [281/1238] bpb=1.020204 time=62.5s
+ ttt_chunk [291/1238] bpb=1.020790 time=64.7s
+ ttt_chunk [301/1238] bpb=1.020672 time=66.9s
+ ttt_chunk [311/1238] bpb=1.019243 time=69.2s
+ ttt_chunk [321/1238] bpb=1.019116 time=71.4s
+ ttt_chunk [331/1238] bpb=1.019040 time=73.6s
+ ttt_chunk [341/1238] bpb=1.018229 time=75.8s
+ ttt_chunk [351/1238] bpb=1.018871 time=78.0s
+ ttt_chunk [361/1238] bpb=1.017753 time=80.3s
+ ttt_chunk [371/1238] bpb=1.016320 time=82.5s
+ ttt_chunk [381/1238] bpb=1.016240 time=84.7s
+ ttt_chunk [391/1238] bpb=1.015657 time=86.9s
+ ttt_chunk [401/1238] bpb=1.015373 time=89.1s
+ ttt_chunk [411/1238] bpb=1.015672 time=91.4s
+ ttt_chunk [421/1238] bpb=1.015077 time=93.6s
+ ttt_chunk [431/1238] bpb=1.015040 time=95.8s
+ ttt_chunk [441/1238] bpb=1.015070 time=98.0s
+ ttt_chunk [451/1238] bpb=1.016010 time=100.2s
+ ttt_chunk [461/1238] bpb=1.014568 time=102.5s
+ ttt_chunk [471/1238] bpb=1.014500 time=104.7s
+ ttt_chunk [481/1238] bpb=1.014533 time=106.9s
+ ttt_chunk [491/1238] bpb=1.014815 time=109.1s
+ ttt_chunk [501/1238] bpb=1.014372 time=111.3s
+ ttt_chunk [511/1238] bpb=1.014160 time=113.6s
+ ttt_chunk [521/1238] bpb=1.013970 time=115.8s
+ ttt_chunk [531/1238] bpb=1.013986 time=118.0s
+ ttt_chunk [541/1238] bpb=1.014090 time=120.2s
+ ttt_chunk [551/1238] bpb=1.013832 time=122.4s
+ ttt_chunk [561/1238] bpb=1.013488 time=124.6s
+ ttt_chunk [571/1238] bpb=1.012941 time=126.9s
+ ttt_chunk [581/1238] bpb=1.013102 time=129.1s
+ ttt_chunk [591/1238] bpb=1.013289 time=131.3s
+ ttt_chunk [601/1238] bpb=1.013350 time=133.5s
+ ttt_chunk [611/1238] bpb=1.013838 time=135.8s
+ ttt_chunk [621/1238] bpb=1.014490 time=138.0s
+ ttt_chunk [631/1238] bpb=1.014397 time=140.2s
+ ttt_chunk [641/1238] bpb=1.014546 time=142.4s
+ ttt_chunk [651/1238] bpb=1.014782 time=144.6s
+ ttt_chunk [661/1238] bpb=1.014096 time=146.9s
+ ttt_chunk [671/1238] bpb=1.013732 time=149.1s
+ ttt_chunk [681/1238] bpb=1.014785 time=151.3s
+ ttt_chunk [691/1238] bpb=1.014629 time=153.5s
+ ttt_chunk [701/1238] bpb=1.014317 time=155.7s
+ ttt_chunk [711/1238] bpb=1.014773 time=158.0s
+ ttt_chunk [721/1238] bpb=1.014972 time=160.2s
+ ttt_chunk [731/1238] bpb=1.014384 time=162.4s
+ ttt_chunk [741/1238] bpb=1.014238 time=164.6s
+ ttt_chunk [751/1238] bpb=1.013423 time=166.9s
+ ttt_chunk [761/1238] bpb=1.012689 time=169.1s
+ ttt_chunk [771/1238] bpb=1.011828 time=171.3s
+ ttt_chunk [781/1238] bpb=1.011687 time=173.5s
+ ttt_chunk [791/1238] bpb=1.011931 time=175.7s
+ ttt_chunk [801/1238] bpb=1.011893 time=178.0s
+ ttt_chunk [811/1238] bpb=1.011360 time=180.2s
+ ttt_chunk [821/1238] bpb=1.010525 time=182.4s
+ ttt_chunk [831/1238] bpb=1.010334 time=184.6s
+ ttt_chunk [841/1238] bpb=1.009991 time=186.8s
+ ttt_chunk [851/1238] bpb=1.009671 time=189.0s
+ ttt_chunk [861/1238] bpb=1.009008 time=191.3s
+ ttt_chunk [871/1238] bpb=1.008763 time=193.5s
+ ttt_chunk [881/1238] bpb=1.008406 time=195.7s
+ ttt_chunk [891/1238] bpb=1.007928 time=197.9s
+ ttt_chunk [901/1238] bpb=1.007530 time=200.1s
+ ttt_chunk [911/1238] bpb=1.007394 time=202.3s
+ ttt_chunk [921/1238] bpb=1.007715 time=204.5s
+ ttt_chunk [931/1238] bpb=1.008385 time=206.7s
+ ttt_chunk [941/1238] bpb=1.008748 time=209.0s
+ ttt_chunk [951/1238] bpb=1.008698 time=211.2s
+ ttt_chunk [961/1238] bpb=1.009308 time=213.4s
+ ttt_chunk [971/1238] bpb=1.009328 time=215.6s
+ ttt_chunk [981/1238] bpb=1.009706 time=217.8s
+ ttt_chunk [991/1238] bpb=1.009542 time=220.0s
+ ttt_chunk [1001/1238] bpb=1.009791 time=222.2s
+ ttt_chunk [1011/1238] bpb=1.010104 time=224.5s
+ ttt_chunk [1021/1238] bpb=1.010661 time=226.7s
+ ttt_chunk [1031/1238] bpb=1.011141 time=228.9s
+ ttt_chunk [1041/1238] bpb=1.011344 time=231.1s
+ ttt_chunk [1051/1238] bpb=1.011212 time=233.3s
+ ttt_chunk [1061/1238] bpb=1.011308 time=235.5s
+ ttt_chunk [1071/1238] bpb=1.011365 time=237.7s
+ ttt_chunk [1081/1238] bpb=1.011250 time=240.0s
+ ttt_chunk [1091/1238] bpb=1.011353 time=242.2s
+ ttt_chunk [1101/1238] bpb=1.011750 time=244.4s
+ ttt_chunk [1111/1238] bpb=1.011970 time=246.6s
+ ttt_chunk [1121/1238] bpb=1.012126 time=248.8s
+ ttt_chunk [1131/1238] bpb=1.011763 time=251.1s
+ ttt_chunk [1141/1238] bpb=1.011461 time=253.3s
+ ttt_chunk [1151/1238] bpb=1.011522 time=255.5s
+ ttt_chunk [1161/1238] bpb=1.011662 time=257.7s
+ ttt_chunk [1171/1238] bpb=1.011446 time=259.9s
+ ttt_chunk [1181/1238] bpb=1.011074 time=262.2s
+ ttt_chunk [1191/1238] bpb=1.011168 time=264.4s
+ ttt_chunk [1201/1238] bpb=1.011388 time=266.6s
+ ttt_chunk [1211/1238] bpb=1.011198 time=268.8s
+ ttt_chunk [1221/1238] bpb=1.010771 time=271.0s
+ ttt_chunk [1231/1238] bpb=1.010472 time=273.3s
+ ttt_chunk [1238/1238] bpb=1.010459 time=274.6s
+ttt_gdn:done val_loss=3.072542 val_bpb=1.010565 elapsed=274.6s
+TTT BPB: 1.010565 (delta: -0.009298)
+TTT eval time: 275.0s
+final_int6_ttt_exact val_loss:3.07254215 val_bpb:1.01056452
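
The `final_int6_ttt_exact` lines above report both `val_loss` (nats/token) and `val_bpb`, and the artifact sizes are judged against the 16,000,000-byte decimal cap. A minimal sketch of the sanity checks a reviewer can run on these logged numbers — the bytes-per-token ratio is not stated anywhere in this record, so here it is inferred from the logged (loss, bpb) pairs themselves under the assumption `bpb = val_loss / ln(2) / bytes_per_token`:

```python
import math

# (val_loss in nats/token, reported val_bpb) pairs from the seed-42 and
# seed-999 log tails above
pairs = [(3.07705825, 1.01204987), (3.07254215, 1.01056452)]

# Under bpb = loss / ln(2) / bytes_per_token, each pair implies a ratio;
# if the logs are self-consistent, both seeds must imply the same one.
ratios = [loss / math.log(2) / bpb for loss, bpb in pairs]
assert max(ratios) - min(ratios) < 1e-3

# Artifact sizes from the results table, checked against the decimal cap
# (16,000,000 bytes, not 16 MiB = 16,777,216 bytes).
CAP_BYTES = 16_000_000
for artifact in (15_543_829, 15_527_172, 15_524_066):
    assert artifact < CAP_BYTES

print(f"implied bytes/token ~ {ratios[0]:.3f}")
```

Both logged pairs imply the same ratio to within rounding, which is a quick cross-check that the `val_loss`/`val_bpb` columns were produced by the same evaluation code path.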