From 824a1d33f830c3687db59c4fe31b0c750cc8c4e7 Mon Sep 17 00:00:00 2001
From: Lei Zhang
Date: Thu, 16 Apr 2026 22:59:16 -0700
Subject: [PATCH 1/3] Submit the faithful K_KVShare_Wider FLA package as one
 records-only branch

This branch lifts the validated review package onto a clean upstream/main
base so the official submission diff stays confined to one records folder
and one commit. The package keeps the faithful multi-file surface: the
packed single-file experiments drifted, while a direct smoke test on the
current multi-file surface matched the measured candidate within noise.

Constraint: The submission branch must contain only records/ files and
must keep the exact measured candidate surface.
Rejected: Reuse the existing fork review branch as-is | it carries many
exploratory commits and is noisier than a clean submit branch
Rejected: Promote the packed single-file variant | it was not
fidelity-cleared for this candidate
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: If packaging changes again, rerun at least one packaged smoke
before treating the branch as submission-ready
Tested: py_compile on packaged Python files; exact folder-size audit
(15,991,282 bytes total); packaged multi-file smoke on PR-head surface at
1.03971272 BPB
Not-tested: Re-running the full 3-seed sweep on this rebased records-only
branch (package contents unchanged)
---
 .../2026-04-16_KKVShareWider_FLA/README.md    |   56 +
 .../architectures.py                          |  709 ++++++++++
 .../2026-04-16_KKVShareWider_FLA/configs.py   |  316 +++++
 .../requirements.txt                          |   10 +
 .../submission.json                           |   44 +
 .../train_gdn_7k.py                           | 1172 +++++++++++++++++
 .../2026-04-16_KKVShareWider_FLA/train_gpt.py |   87 ++
 .../train_seed1337.log                        |   77 ++
 .../train_seed2025.log                        |   76 ++
 .../train_seed42.log                          |   77 ++
 10 files changed, 2624 insertions(+)
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/README.md
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/architectures.py
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/configs.py
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/requirements.txt
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/submission.json
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_gdn_7k.py
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_gpt.py
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed1337.log
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed2025.log
 create mode 100644 records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed42.log

diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/README.md b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/README.md
new file mode 100644
index 0000000000..8b9529ac96
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/README.md
@@ -0,0 +1,56 @@
+# Record: K_KVShare_Wider full-recipe FLA
+
+**val_bpb: 1.0409** (3-seed mean, std 0.0011) | **3.1648 nats** | **8xH100 SXM, 600s** | **No TTT**
+
+FLA / GatedDeltaNet candidate using `K_KVShare_Wider` on a fuller
+upstream-style recipe. Nearest prior family reference: PR `#1370`.
+The packaged `train_gpt.py` performs no runtime dependency downloads.
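+
+For the record's metrics: BPB is bits-per-token scaled by the corpus
+tokens-per-byte ratio, exactly as `eval_val_sliding` in `train_gdn_7k.py`
+computes it. A minimal sketch of the conversion, using the 3-seed means
+reported below:
+
+```python
+import math
+
+val_loss_nats = 3.16476760        # mean val_loss (nats/token)
+val_bpb = 1.04089763              # mean quantized BPB (bits/byte)
+
+bits_per_token = val_loss_nats / math.log(2)   # ~4.566 bits/token
+tokens_per_byte = val_bpb / bits_per_token     # ~0.228, i.e. ~4.39 bytes/token
+```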
+
+## Results
+
+| Seed | Steps | Post-EMA BPB | **Quantized BPB** | val_loss (nats) | Artifact |
+|------|------:|-------------:|------------------:|----------------:|---------:|
+| 1337 | 1652 | 1.020660 | **1.03967403** | 3.16104735 | 15,762,406 |
+| 42 | 1652 | 1.022042 | **1.04153708** | 3.16671180 | 15,870,797 |
+| 2025 | 1583 | 1.023994 | **1.04148177** | 3.16654364 | 15,648,800 |
+| **Mean** | **1629** | **1.022232** | **1.04089763** | **3.16476760** | **15,760,668** |
+
+## Technique
+
+- FLA / GatedDeltaNet family (`K_KVShare_Wider`)
+- KV sharing is used to buy width rather than depth
+- Fuller upstream-style recipe
+- EMA + SWA + late QAT + int6 artifact path
+- final scored line in all logs is `final_int6_roundtrip_exact`
+
+Not used:
+- no TTT
+- no SLOT
+- no n-gram overlay
+- no SWA/XSA final scoring path (`K_KVShare_Wider` has `num_swa_layers = 0`)
+
+## Compliance Notes
+
+- train uses `train_files`; scoring uses `val_files`
+- no eval-time adaptation
+- `train_gpt.py` does not download dependencies during evaluation
+- dependencies are installed beforehand via `requirements.txt`
+- max artifact bytes across reported seeds: `15,870,797`
+- full packaged-folder audit remains under `16,000,000` bytes
+
+## Reproducibility
+
+Install dependencies before evaluation:
+
+```bash
+pip install -r requirements.txt
+```
+
+Prepare the SP8192 cached dataset/tokenizer as usual, then run one seed with:
+
+```bash
+SEED=$SEED ARCH_MODE=K MAX_WALLCLOCK_SECONDS=600 VAL_LOSS_EVERY=0 EVAL_COMPILE_ENABLED=0 torchrun --standalone --nproc_per_node=8 train_gpt.py
+```
+
+`EVAL_COMPILE_ENABLED=0` is an operational stability setting for final-eval
+robustness; it does not change the model family or scored path.
diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/architectures.py b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/architectures.py
new file mode 100644
index 0000000000..dff180156c
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/architectures.py
@@ -0,0 +1,709 @@
+"""GDN Hybrid Architecture — modular blocks using FLA native layers.
+
+Supports 12 model variants (A-K, plus B2) for the Parameter Golf screening
+experiments. Each model is a stack of mixed {GDN, DeltaProduct, RWKV-7,
+Mamba-2, SWA} blocks with shared MLP, RMSNorm, and residual connections.
+
+Key design choices:
+- FLA layers handle recurrent attention (GatedDeltaNet, GatedDeltaProduct, RWKV7, Mamba2)
+- Sliding Window Attention (SWA) uses flash attention with a causal window mask
+- All blocks follow the same pre-norm residual pattern for uniform gradient flow
+- Weight sharing for SWA layers in Zamba/Hymba-style models
+- Score-first eval: XSA-all only extends attention layers (no future context leakage)
+"""
+from __future__ import annotations
+import math
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+# ─── FLA backend selection ──────────────────────────────────────────────────
+# Set FLA_USE_NAIVE=1 to force pure-PyTorch (naive) kernels instead of Triton.
+# This is needed when:
+# - Running on V100 (sm_70) which doesn't support FLA's Triton kernels well
+# - Triton cache is corrupted (FileNotFoundError on .json files)
+# - Debugging without Triton dependency
+#
+# On A100 (sm_80+), the Triton kernels are ~3-10x faster and should be used.
+_USE_NAIVE = os.environ.get("FLA_USE_NAIVE", "0") == "1"
+
+if _USE_NAIVE:
+    # 1.
Patch GatedDeltaNet's chunk op + import fla.ops.gated_delta_rule.chunk as _gdr_chunk + import fla.ops.gated_delta_rule.naive as _gdr_naive + + def _patched_chunk_gated_delta_rule( + q, k, v, g, beta, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdr_naive.naive_chunk_gated_delta_rule( + q, k, v, g, beta, + chunk_size=64, scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + ) + + _gdr_chunk.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + import fla.layers.gated_deltanet as _gdn_layer + _gdn_layer.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + + # 2. Patch GatedDeltaProduct's chunk op + import fla.ops.gated_delta_product.chunk as _gdp_chunk + import fla.ops.gated_delta_product.naive as _gdp_naive + + def _patched_chunk_gated_delta_product( + q, k, v, g, beta, num_householder=1, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdp_naive.naive_recurrent_gated_delta_product( + q, k, v, g, beta, + scale=scale, cu_seqlens=None, + initial_state=initial_state, + output_final_state=output_final_state, + num_householder=num_householder, + ) + + _gdp_chunk.chunk_gated_delta_product = _patched_chunk_gated_delta_product + import fla.layers.gated_deltaproduct as _gdp_layer + _gdp_layer.chunk_gated_delta_product = _patched_chunk_gated_delta_product + + print("[FLA] Using NAIVE (pure-PyTorch) kernels — set FLA_USE_NAIVE=0 for Triton", flush=True) + +# FLA imports +from fla.layers import GatedDeltaNet, GatedDeltaProduct, Mamba2 +try: + from fla.layers import RWKV7Attention +except Exception: + RWKV7Attention = None # type: ignore + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except ImportError: + def flash_attn_3_func(q, k, v, causal=False, window_size=(-1, -1)): + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + if k2.size(1) != q2.size(1): + rep = q2.size(1) // k2.size(1) + k2 = k2.repeat_interleave(rep, dim=1) + v2 = v2.repeat_interleave(rep, dim=1) + out = torch.nn.functional.scaled_dot_product_attention(q2, k2, v2, is_causal=causal) + return out.transpose(1, 2) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int | None = None, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + """Linear layer that casts input to weight dtype for mixed precision. 
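+    Weights keep their full-precision master copy; only the forward matmul
+    sees the cast (and, under late QAT, fake-quantized) view.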
+    Supports late QAT (int6 STE) when _qat_enabled is set."""
+    _qat_enabled: bool = False
+
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight.to(dtype=x.dtype)
+        if CastedLinear._qat_enabled and self.training and w.ndim == 2:
+            with torch.no_grad():
+                w32 = self.weight.float()
+                row_max = w32.abs().amax(dim=1)
+                scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
+                w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) * scale[:, None]).to(x.dtype)
+            w = w + (w_q - w).detach()  # STE: forward uses quantized, backward uses full
+        bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None
+        return F.linear(x, w, bias)
+
+
+class Rotary(nn.Module):
+    """RoPE embeddings for sliding window attention."""
+    def __init__(self, dim: int, base: float = 10000.0, max_len: int = 4096):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.max_len = max_len
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype):
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+        freqs = torch.outer(t, self.inv_freq.to(device))
+        cos = freqs.cos().to(dtype)
+        sin = freqs.sin().to(dtype)
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    """Apply RoPE to x of shape (batch, seq, heads, head_dim); cos/sin are (seq, head_dim/2)."""
+    d = x.shape[-1] // 2
+    x1, x2 = x[..., :d], x[..., d:]
+    # Index the tables by position (dim 1) and broadcast over batch and heads.
+    cos = cos[:x.shape[1]][None, :, None, :]
+    sin = sin[:x.shape[1]][None, :, None, :]
+    out1 = x1 * cos - x2 * sin
+    out2 = x2 * cos + x1 * sin
+    return torch.cat([out1, out2], dim=-1)
+
+
+class MLP(nn.Module):
+    """Feed-forward MLP with configurable activation."""
+    def __init__(self, dim: int, mult: float = 3.0, act: str = "relu_sq", leaky_slope: float = 0.5):
+        super().__init__()
+        hidden = int(mult * dim)
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        nn.init.zeros_(self.proj.weight)  # zero-init output for residual
+        self.act = act
+        self.leaky_slope = leaky_slope
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc(x)
+        if self.act == "leaky_relu_sq":
+            x = F.leaky_relu(x, negative_slope=self.leaky_slope)
+        else:
+            x = F.relu(x)
+        return self.proj(x.square())
+
+
+class SlidingWindowAttention(nn.Module):
+    """Sliding window causal attention for hybrid models.
+
+    Supports XSA (cross-segment attention) at eval time for extending context
+    across eval chunks. The window is enforced during training but can be
+    relaxed at eval. KV can be shared across layers (Zamba-style) by reusing
+    the same module.
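+
+    Minimal standalone sketch (hypothetical shapes; HybridGDN normally
+    builds and wires these modules itself):
+
+        swa = SlidingWindowAttention(dim=512, num_heads=8, num_kv_heads=4)
+        x = torch.randn(2, 128, 512)
+        y = swa(x)            # -> (2, 128, 512)
+        swa.use_xsa = True    # eval-time cross-segment mode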
+ """ + def __init__( + self, + dim: int, + num_heads: int = 8, + num_kv_heads: int = 4, + window_size: int = 512, + rope_base: float = 10000.0, + qk_gain_init: float = 1.5, + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.window_size = window_size + + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + nn.init.zeros_(self.proj.weight) + + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.use_xsa = False # enabled at eval time for XSA-all + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """XSA: subtract self-value projection (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + B, T, D = x.shape + q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(B, T, self.num_kv_heads, self.head_dim) + + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(T, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + + if q.is_cuda and q.dtype not in (torch.float16, torch.bfloat16): + q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16) + + # Use window during training, full causal at eval if XSA enabled + y = flash_attn_3_func(q, k, v, causal=True) + + if self.use_xsa: + y = self._xsa_efficient(y, v) + + y = y.reshape(B, T, D) + return self.proj(y) + + +class RecurrentBlock(nn.Module): + """Wraps any FLA recurrent layer (GDN, DeltaProduct, RWKV-7, Mamba-2) with + pre-norm residual connection and MLP.""" + + def __init__( + self, + dim: int, + recurrent_layer: nn.Module, + mlp_mult: float = 3.0, + mlp_act: str = "relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.recurrent = recurrent_layer + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # FLA layers return (output, state) or just output depending on mode + recurrent_out = self.recurrent(self.attn_norm(x_in)) + if isinstance(recurrent_out, tuple): + recurrent_out = recurrent_out[0] + + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * recurrent_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out)) + return x_out + + +class AttentionBlock(nn.Module): + """SWA block with pre-norm residual and MLP.""" + + def __init__( + self, + dim: int, + 
swa: SlidingWindowAttention, + mlp_mult: float = 3.0, + mlp_act: str = "relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.attn = swa + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in), v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out)) + return x_out + + +class SmearGate(nn.Module): + """Weighted average of current and previous token embeddings.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash-based bigram/trigram embedding for additional context.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int, + trigram: bool = False): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self._trigram = trigram + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def trigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., :2] = mod + out[..., 2:] = (36313 * t[..., 2:] ^ 27191 * t[..., 1:-1] ^ 51497 * t[..., :-2]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self._trigram: + h = h + self.embed(self.trigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +def _parse_layout(layout_str: str) -> list[tuple[str, int]]: + """Parse a layout string into a sequence of (layer_type, count) pairs. 
+ + Examples: + "gdn_only" -> [("gdn", 11)] (count filled in by caller) + "gdn5_swa_gdn5_swa_shared" -> [("gdn", 5), ("swa", 1), ("gdn", 5), ("swa_shared", 1)] + "gdn3_swa_gdn3_swa_shared_gdn3" -> [("gdn", 3), ("swa", 1), ("gdn", 3), ("swa_shared", 1), ("gdn", 3)] + "mamba_only" -> [("mamba", 11)] + "gdn3_mamba2_swa_gdn3_mamba2" -> [("gdn", 3), ("mamba", 2), ("swa", 1), ("gdn", 3), ("mamba", 2)] + """ + if layout_str == "gdn_only": + return [("gdn", -1)] # -1 = use num_gdn_layers + if layout_str == "mamba_only": + return [("mamba", -1)] # -1 = use num_mamba_layers + if layout_str == "swa_only": + return [("swa", -1)] # -1 = use num_swa_layers + + # Parse custom layouts like "gdn5_swa_gdn5_swa_shared" + parts = layout_str.split("_") + result = [] + i = 0 + while i < len(parts): + part = parts[i] + if part.startswith("gdn") and len(part) > 3: + count = int(part[3:]) + result.append(("gdn", count)) + elif part.startswith("mamba") and len(part) > 5: + count = int(part[5:]) + result.append(("mamba", count)) + elif part == "swa": + # Check if next token is "shared" + if i + 1 < len(parts) and parts[i + 1] == "shared": + result.append(("swa_shared", 1)) + i += 1 + else: + result.append(("swa", 1)) + elif part == "shared": + # Already consumed by swa check above + pass + i += 1 + return result + + +class HybridGDN(nn.Module): + """Hybrid GDN architecture supporting mixed recurrent/attention layers. + + Builds a stack of blocks according to the layer_layout specification: + - "gdn" blocks use GatedDeltaNet (or GatedDeltaProduct, or RWKV-7) + - "mamba" blocks use Mamba-2 + - "swa" blocks use SlidingWindowAttention + - "swa_shared" reuses the same SWA module (Zamba-style weight sharing) + + All models share: token embedding, bigram hash, smear gate, final norm, lm_head. 
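+
+    Minimal end-to-end sketch (uses this package's configs.get_config; the
+    batch shape is illustrative):
+
+        from configs import get_config
+        model = HybridGDN(get_config("K"), vocab_size=1024)
+        ids = torch.randint(0, 1024, (2, 257))
+        loss = model(ids[:, :-1], ids[:, 1:])   # scalar cross-entropy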
+ """ + def __init__(self, config: dict, vocab_size: int = 1024): + super().__init__() + dim = config["model_dim"] + num_heads = config["num_heads"] + mlp_mult = config["mlp_mult"] + self.arch_name = config["arch_name"] + self.model_dim = dim + self.vocab_size = vocab_size + self.logit_softcap = 30.0 + + # Embeddings + self.tok_emb = nn.Embedding(vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.005) + self.bigram = BigramHashEmbedding( + config.get("bigram_vocab_size", 2048), + config.get("bigram_dim", 128), + dim, + trigram=config.get("trigram", False), + ) + self.smear = SmearGate(dim) + + # Meta tokens (Hymba-style, for Model E) + n_meta = config.get("meta_tokens", 0) + if n_meta > 0: + self.meta_tokens = nn.Parameter(torch.randn(1, n_meta, dim) * 0.02) + self.n_meta = n_meta + else: + self.meta_tokens = None + self.n_meta = 0 + + # Build layer stack + layout = _parse_layout(config["layer_layout"]) + self.blocks = nn.ModuleList() + self._block_types = [] # track type for XSA/diagnostics + self._shared_swa = None # shared SWA module for Zamba/Hymba models + + layer_idx = 0 + for layer_type, count in layout: + if count == -1: + # Fill with the specified layer type + if layer_type == "gdn": + count = config["num_gdn_layers"] + elif layer_type == "mamba": + count = config["num_mamba_layers"] + elif layer_type == "swa": + count = config["num_swa_layers"] + + for _ in range(count): + if layer_type == "gdn": + recurrent = self._make_recurrent_layer(config, layer_idx) + block = RecurrentBlock(dim, recurrent, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("gdn") + + elif layer_type == "mamba": + mamba_expand = config.get("mamba_expand", 2) + mamba_head_dim = config.get("gdn_head_dim", 64) + mamba_num_heads = (dim * mamba_expand) // mamba_head_dim + mamba = Mamba2( + num_heads=mamba_num_heads, + head_dim=mamba_head_dim, + hidden_size=dim, + state_size=config.get("mamba_state_size", 64), + expand=mamba_expand, + layer_idx=layer_idx, + ) + block = RecurrentBlock(dim, mamba, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("mamba") + + elif layer_type in ("swa", "swa_shared"): + if layer_type == "swa_shared" and self._shared_swa is not None: + swa = self._shared_swa # reuse same SWA module + else: + swa = SlidingWindowAttention( + dim=dim, + num_heads=num_heads, + num_kv_heads=config.get("swa_num_kv_heads", 4), + window_size=config.get("swa_window", 512), + ) + if config.get("swa_shared", False): + self._shared_swa = swa + + # Each SWA position gets its own MLP even if SWA weights are shared + block = AttentionBlock(dim, swa, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("swa" if layer_type == "swa" else "swa_shared") + + layer_idx += 1 + + # KV sharing: share k/v projections between adjacent layers + kv_stride = config.get("kv_sharing_stride", 0) + if kv_stride > 0: + self._apply_kv_sharing(kv_stride) + + self.final_norm = RMSNorm(dim) + # Tied embeddings (standard for parameter golf) + self.lm_head = None # use tok_emb.weight + self._init_weights() + + def _make_recurrent_layer(self, config: dict, layer_idx: int) -> nn.Module: + """Create the appropriate recurrent layer based on config.""" + dim = config["model_dim"] + num_heads = config["num_heads"] + + if config.get("use_rwkv7", False): + total_layers = config.get("num_gdn_layers", 11) + return RWKV7Attention( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + 
layer_idx=layer_idx, + num_hidden_layers=total_layers, + mode="chunk", + ) + elif config.get("use_deltaproduct", False): + return GatedDeltaProduct( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + num_householder=config.get("dp_num_householder", 2), + allow_neg_eigval=config.get("dp_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + else: + # Default: GatedDeltaNet + return GatedDeltaNet( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + allow_neg_eigval=config.get("gdn_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + + def _apply_kv_sharing(self, stride: int) -> None: + """Share KV projection modules between adjacent layer groups. + + For GDN layers: shares k_proj, v_proj, k_conv1d, v_conv1d. + For SWA layers: shares c_k, c_v. + The first layer in each group is the anchor; subsequent layers in the + group become followers that reference the anchor's modules. + """ + # Collect indices by block type + gdn_indices = [i for i, t in enumerate(self._block_types) if t == "gdn"] + swa_indices = [i for i, t in enumerate(self._block_types) + if t in ("swa", "swa_shared")] + + # Share GDN KV projections within each stride-group + for group_start in range(0, len(gdn_indices), stride): + anchor_idx = gdn_indices[group_start] + anchor = self.blocks[anchor_idx].recurrent + for j in range(1, stride): + if group_start + j >= len(gdn_indices): + break + follower_idx = gdn_indices[group_start + j] + follower = self.blocks[follower_idx].recurrent + follower.k_proj = anchor.k_proj + follower.v_proj = anchor.v_proj + follower.k_conv1d = anchor.k_conv1d + follower.v_conv1d = anchor.v_conv1d + + # Share SWA KV projections within each stride-group + for group_start in range(0, len(swa_indices), stride): + anchor_idx = swa_indices[group_start] + anchor = self.blocks[anchor_idx].attn + for j in range(1, stride): + if group_start + j >= len(swa_indices): + break + follower_idx = swa_indices[group_start + j] + follower = self.blocks[follower_idx].attn + follower.c_k = anchor.c_k + follower.c_v = anchor.c_v + + def _init_weights(self) -> None: + """Weight initialization. + + Each sub-module handles its own init (MLP zeros proj, SWA zeros proj, + FLA layers do own init). We just do the residual scaling for output + projections on our own CastedLinear layers. + """ + total_layers = len(self.blocks) + for name, p in self.named_parameters(): + # Skip FLA-internal parameters + if ".recurrent." 
in name: + continue + # Scale down output projections for residual stream + if p.ndim == 2 and "proj" in name and "bigram" not in name: + with torch.no_grad(): + p.mul_(1.0 / math.sqrt(2 * total_layers)) + + def set_xsa(self, enable: bool = True) -> None: + """Enable/disable XSA on all attention blocks.""" + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + block.attn.use_xsa = enable + + def _compute_logits(self, x: Tensor) -> Tensor: + """Compute logits with tied embeddings and softcap.""" + logits = F.linear(x, self.tok_emb.weight) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + """Forward pass returning cross-entropy loss.""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + + # Prepend meta tokens if Hymba-style + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + x = block(x, x0) + else: + x = block(x, x0) + + # Remove meta tokens before computing logits + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + + x = self.final_norm(x) + logits = self._compute_logits(x.reshape(-1, x.size(-1))) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning logits (for evaluation).""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + x = block(x, x0) + else: + x = block(x, x0) + + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + + x = self.final_norm(x) + return self._compute_logits(x) + + def get_diagnostics(self) -> dict: + """Collect per-layer weight statistics for checkpoint diagnostics.""" + diag = {} + for i, (block, btype) in enumerate(zip(self.blocks, self._block_types)): + prefix = f"layer_{i}_{btype}" + for name, param in block.named_parameters(): + if param.ndim >= 2: + w = param.data.float() + diag[f"{prefix}/{name}/std"] = w.std().item() + diag[f"{prefix}/{name}/kurtosis"] = (((w - w.mean()) / (w.std() + 1e-8)) ** 4).mean().item() - 3.0 + return diag + + def count_params(self) -> dict: + """Count parameters by category.""" + cats = {"embedding": 0, "recurrent": 0, "attention": 0, "mlp": 0, "other": 0} + for name, p in self.named_parameters(): + n = p.numel() + if "tok_emb" in name or "bigram" in name: + cats["embedding"] += n + elif any(k in name for k in ["recurrent", "gdn", "mamba", "rwkv", "delta"]): + cats["recurrent"] += n + elif "attn" in name or "c_q" in name or "c_k" in name or "c_v" in name: + cats["attention"] += n + elif "mlp" in name or "fc" in name: + cats["mlp"] += n + else: + cats["other"] += n + cats["total"] = sum(cats.values()) + return cats diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/configs.py b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/configs.py new file mode 100644 index 
0000000000..5bbdac3bd4 --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/configs.py @@ -0,0 +1,316 @@ +"""Model architecture configurations for GDN Hybrid experiments. + +Each config returns a dict consumed by HybridGDN.__init__. +All models are sized to fit ~16MB at int6+zstd-22. + +Models A-H: baseline architecture sweeps. +Models I-K: KV sharing experiments (kv_sharing_stride=2). +""" +from __future__ import annotations + + +def model_a_pure_gdn() -> dict: + """Model A: Pure GDN (Baseline) — 10 layers Gated DeltaNet.""" + return dict( + arch_name="A_PureGDN", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_b_deltaproduct() -> dict: + """Model B: Gated DeltaProduct n_h=2 — rank-2 state transitions.""" + return dict( + arch_name="B_DeltaProduct", + num_gdn_layers=10, # 10 layers to fit param budget + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, # slightly narrower to fit 16MB + num_heads=8, + mlp_mult=3.0, + use_deltaproduct=True, + dp_num_householder=2, + dp_allow_neg_eigval=False, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + ) + + +def model_b2_deltaproduct_neg() -> dict: + """Model B2: DeltaProduct + negative eigenvalues.""" + cfg = model_b_deltaproduct() + cfg["arch_name"] = "B2_DeltaProduct_NegEig" + cfg["dp_allow_neg_eigval"] = True + return cfg + + +def model_c_gdn_neg() -> dict: + """Model C: GDN with negative eigenvalues — richer state dynamics. + + (Originally RWKV-7, replaced because RWKV7 requires Triton kernels with + no pure-PyTorch fallback available.) 
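+
+    With `gdn_allow_neg_eigval=True`, the delta-rule update is widened so
+    state-transition eigenvalues can go negative, permitting oscillatory
+    (sign-flipping) state dynamics rather than pure decay.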
+ """ + return dict( + arch_name="C_GDN_NegEig", + num_gdn_layers=11, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=True, # Key difference: negative eigenvalues + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + ) + + +def model_d_gdn_1swa() -> dict: + """Model D: GDN + 1 Shared SWA (Zamba-style).""" + return dict( + arch_name="D_GDN_1SWA", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=1, + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + # Layout: [GDN×5] → [SWA] → [GDN×5] → [SWA_shared] + layer_layout="gdn5_swa_gdn5_swa_shared", + ) + + +def model_e_gdn_2swa() -> dict: + """Model E: GDN + 2 Shared SWA (Hymba-inspired) with meta-tokens.""" + return dict( + arch_name="E_GDN_2SWA_Hymba", + num_gdn_layers=9, + num_mamba_layers=0, + num_swa_layers=1, # 1 unique, shared at 2 positions + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=4, # Hymba-style prepended meta-tokens + # Layout: [GDN×3] → [SWA] → [GDN×3] → [SWA_shared] → [GDN×3] + layer_layout="gdn3_swa_gdn3_swa_shared_gdn3", + ) + + +def model_f_mamba2() -> dict: + """Model F: Mamba-2 Pure (Mamba-3 proxy with RoPE on B/C).""" + return dict( + arch_name="F_Mamba2", + num_gdn_layers=0, + num_mamba_layers=11, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + mamba_state_size=64, + mamba_expand=2, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="mamba_only", + ) + + +def model_g_hybrid() -> dict: + """Model G: GDN + Mamba-2 + SWA triple hybrid.""" + return dict( + arch_name="G_GDN_Mamba_SWA", + num_gdn_layers=6, + num_mamba_layers=4, + num_swa_layers=1, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + mamba_state_size=64, + mamba_expand=2, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + # Layout: [GDN×3] → [Mamba×2] → [SWA] → [GDN×3] → [Mamba×2] + layer_layout="gdn3_mamba2_swa_gdn3_mamba2", + ) + + +def model_h_pure_swa() -> dict: + """Model H: Pure Sliding Window Attention (standard softmax) — control baseline. + + All 10 layers use causal sliding-window softmax attention (no GDN). + Same MLP, embedding, and normalization as Model A for fair comparison. 
+ """ + return dict( + arch_name="H_PureSWA", + num_gdn_layers=0, + num_mamba_layers=0, + num_swa_layers=10, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="swa_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_i_kv_share() -> dict: + """Model I: GDN + KV Share — same as A but with kv_sharing_stride=2.""" + return dict( + arch_name="I_KVShare", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +def model_j_kv_share_deeper() -> dict: + """Model J: GDN + KV Share + Deeper — 12L dim=480, near iso-parameter to A.""" + return dict( + arch_name="J_KVShare_Deeper", + num_gdn_layers=12, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +def model_k_kv_share_wider() -> dict: + """Model K: GDN + KV Share + Wider — 10L dim=544, iso-parameter to A.""" + return dict( + arch_name="K_KVShare_Wider", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=544, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +ALL_CONFIGS = { + "A": model_a_pure_gdn, + "B": model_b_deltaproduct, + "B2": model_b2_deltaproduct_neg, + "C": model_c_gdn_neg, + "D": model_d_gdn_1swa, + "E": model_e_gdn_2swa, + "F": model_f_mamba2, + "G": model_g_hybrid, + "H": model_h_pure_swa, + "I": model_i_kv_share, + "J": model_j_kv_share_deeper, + "K": model_k_kv_share_wider, +} + + +def get_config(model_id: str) -> dict: + """Get config by model ID (A, B, B2, C, D, E, F, G, H, I, J, K).""" + if model_id not in ALL_CONFIGS: + raise ValueError(f"Unknown model ID '{model_id}'. 
Choose from {list(ALL_CONFIGS.keys())}") + return ALL_CONFIGS[model_id]() diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/requirements.txt b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/requirements.txt new file mode 100644 index 0000000000..3feaed4f64 --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/requirements.txt @@ -0,0 +1,10 @@ +numpy +torch +sentencepiece +zstandard +flash-linear-attention==0.4.2 +fla-core==0.4.2 +triton==3.2.0 +transformers==5.5.4 +tokenizers==0.22.2 +safetensors==0.7.0 diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/submission.json b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/submission.json new file mode 100644 index 0000000000..646524704b --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/submission.json @@ -0,0 +1,44 @@ +{ + "author": "resouer", + "github_id": "resouer", + "name": "K_KVShare_Wider full-recipe FLA", + "blurb": "FLA/GatedDeltaNet candidate on K_KVShare_Wider with the fuller upstream-style recipe and no runtime dependency downloads in train_gpt.py.", + "date": "2026-04-16", + "track": "10min_16mb", + "val_loss": 3.1647676, + "val_bpb": 1.04089763, + "val_loss_std": 0.00322293, + "val_bpb_std": 0.00106003, + "seeds": [ + 1337, + 42, + 2025 + ], + "seed_results": { + "1337": { + "val_loss": 3.16104735, + "val_bpb": 1.03967403, + "artifact_bytes": 15762406, + "steps": 1652 + }, + "42": { + "val_loss": 3.1667118, + "val_bpb": 1.04153708, + "artifact_bytes": 15870797, + "steps": 1652 + }, + "2025": { + "val_loss": 3.16654364, + "val_bpb": 1.04148177, + "artifact_bytes": 15648800, + "steps": 1583 + } + }, + "based_on": "PR #1370 family", + "artifact_bytes_mean": 15760667.666666666, + "artifact_bytes_max": 15870797, + "bytes_total": 15991282, + "code_bytes": 91312, + "hardware": "8xH100 80GB SXM", + "technique_summary": "K_KVShare_Wider + fuller upstream-style FLA recipe + EMA + SWA + late QAT + int6 artifact path" +} diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_gdn_7k.py b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_gdn_7k.py new file mode 100644 index 0000000000..eb030c6103 --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_gdn_7k.py @@ -0,0 +1,1172 @@ +#!/usr/bin/env python3 +"""GDN Hybrid Full Training Script — 7000 steps with all production features. 
+ +Features beyond Phase 1 screening (train_gdn.py): + - EMA weight averaging (decay 0.997) + - SWA (stochastic weight averaging) during late warmdown + - Late QAT (int6 STE in CastedLinear forward during warmdown) + - Mixed int6/int8 quantization with percentile search + - zstd-22 compression for artifact + - Coprime shard ordering via SHARD_ORDER_FILE + - XSA-all sliding window eval on quantized artifact + - Roundtrip validation (load quantized, eval, report exact BPB) + +Environment variables (key additions vs Phase 1): + EMA_DECAY: EMA decay rate (default 0.997) + SWA_ENABLED: 1|0 (default 1) + SWA_EVERY: SWA collection interval (default 50) + LATE_QAT_THRESHOLD: LR scale below which QAT activates (default 0.15) + SHARD_ORDER_FILE: path to file with one shard path per line (coprime ordering) + MUON_MOMENTUM_WARMUP_START: starting momentum for warmup (default 0.85) + MUON_MOMENTUM_WARMUP_STEPS: steps to ramp momentum (default 500) +""" +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +import zstandard +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from architectures import HybridGDN, CastedLinear +from configs import get_config + + +# ─── Hyperparameters ────────────────────────────────────────────────────────── + +class Hyperparameters: + arch_mode = os.environ.get("ARCH_MODE", "A") + data_path = os.environ.get("DATA_PATH", "../data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "../data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + + # Training length + iterations = int(os.environ.get("ITERATIONS", 7000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2100)) # 30% of 7k + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 14100.0)) # 3h55m safety + + # Validation + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + save_every = int(os.environ.get("SAVE_EVERY", 1000)) + + # Optimizer + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = 
float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + + # Eval + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + xsa_eval = bool(int(os.environ.get("XSA_EVAL", "0"))) # during training + eval_compile_enabled = bool(int(os.environ.get("EVAL_COMPILE_ENABLED", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Checkpoint + ckpt_dir = os.environ.get("CKPT_DIR", "checkpoints") + + # Compile + compile_enabled = bool(int(os.environ.get("COMPILE_ENABLED", "1"))) + + # Resume from checkpoint + resume_ckpt = os.environ.get("RESUME_CKPT", "") + + # EMA / SWA + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Late QAT + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # Chained job support + auto_save_seconds = float(os.environ.get("AUTO_SAVE_SECONDS", "0")) + total_iterations = int(os.environ.get("TOTAL_ITERATIONS", "0")) # 0 = same as iterations + + +# ─── Data Loading ───────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + header = np.fromfile(file, dtype=np.uint32, count=256) + assert header[0] == 20240520, f"Bad magic: {header[0]}" + assert header[1] in (1, 7), f"Bad version: {header[1]}" + ntok = int(header[2]) + return torch.from_numpy(np.fromfile(file, dtype=np.uint16, offset=256 * 4)[:ntok].astype(np.int64)) + + +class TokenStream: + """Reads shards sequentially, supports coprime ordering via SHARD_ORDER_FILE.""" + def __init__(self, pattern: str): + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if shard_order_file and os.path.exists(shard_order_file): + with open(shard_order_file) as f: + self.files = [Path(line.strip()) for line in f if line.strip()] + else: + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + assert self.files, f"No files matching {pattern}" + self.idx = 0 + self.buf = load_data_shard(self.files[self.idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.idx = (self.idx + 1) % len(self.files) + self.buf = load_data_shard(self.files[self.idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + parts = [] + remaining = n + while remaining > 0: + avail = self.buf.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + take_n = min(avail, remaining) + parts.append(self.buf[self.pos:self.pos + take_n]) + self.pos += take_n + remaining -= take_n + return torch.cat(parts) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.stream = TokenStream(pattern) + self.rank = rank + self.world_size = world_size + self.device = device + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + tokens_per_rank = global_tokens // self.world_size + seqs_per_rank = tokens_per_rank // seq_len + total_seqs = seqs_per_rank * self.world_size + total_needed = total_seqs * seq_len + 1 + all_tokens = self.stream.take(total_needed) + start = self.rank * seqs_per_rank * seq_len + chunk = all_tokens[start:start + seqs_per_rank * seq_len + 1] + x = chunk[:-1].reshape(seqs_per_rank, seq_len) + y = chunk[1:].reshape(seqs_per_rank, seq_len) + return x.to(self.device), y.to(self.device) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = 
sorted(glob.glob(pattern)) + parts = [load_data_shard(Path(f)) for f in files] + combined = torch.cat(parts) + return combined[:((combined.numel() - 1) // seq_len) * seq_len + 1] + + +def build_sentencepiece_luts(sp, vocab_size, device): + base_bytes = torch.zeros(vocab_size, dtype=torch.float32, device=device) + has_space = torch.zeros(vocab_size, dtype=torch.bool, device=device) + is_boundary = torch.zeros(vocab_size, dtype=torch.bool, device=device) + for i in range(vocab_size): + piece = sp.id_to_piece(i) + raw = piece.encode("utf-8") + base_bytes[i] = len(raw) + if piece.startswith("\u2581"): + has_space[i] = True + base_bytes[i] = len(piece[1:].encode("utf-8")) + 1 + if sp.is_control(i) or sp.is_unknown(i): + is_boundary[i] = True + return base_bytes, has_space, is_boundary + + +def generate_coprime_shard_order(shard_files: list, seed: int = 42) -> list: + """Generate a coprime-stepping shard order for better data mixing. + + Instead of sequential 0,1,2,...,79, uses stride = coprime(N) to visit + all shards in a maximally-spread order. This ensures each training epoch + sees data from diverse shards rather than correlated sequential ones. + """ + n = len(shard_files) + if n <= 1: + return shard_files + + # Find a coprime stride that's roughly n/phi (golden ratio) + target = max(1, int(n / 1.618)) + stride = target + while math.gcd(stride, n) != 1: + stride += 1 + + rng = random.Random(seed) + start = rng.randint(0, n - 1) + order = [] + pos = start + for _ in range(n): + order.append(shard_files[pos]) + pos = (pos + stride) % n + return order + + +# ─── Muon Optimizer ────────────────────────────────────────────────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if transposed: + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay) + super().__init__(params, defaults) + + def step(self, closure=None): + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + for p in group["params"]: + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g + momentum * buf + else: + g = buf + if g.ndim == 2 and min(g.shape) >= 2: + g = zeropower_via_newtonschulz5(g, steps=group["backend_steps"]) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.data.add_(g, alpha=-lr) + + +# ─── Evaluation ────────────────────────────────────────────────────────────── + +def eval_val_sliding( + model: nn.Module, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + rank: int, + world_size: int, + device: torch.device, + seq_len: int = 1024, + stride: int = 64, + batch_seqs: int = 128, + xsa_eval: bool = False, + compile_enabled: bool = True, +) -> tuple[float, float]: + """Score-first sliding window evaluation.""" + total_tokens = val_tokens.numel() - 1 + window_starts = 
[ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + base_model = model.module if hasattr(model, 'module') else model + if xsa_eval and hasattr(base_model, 'set_xsa'): + base_model.set_xsa(True) + + forward_fn = base_model.forward_logits + compiled_logits = forward_fn + if compile_enabled: + try: + compiled_logits = torch.compile(forward_fn, dynamic=False) + except Exception: + compiled_logits = forward_fn + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + + if xsa_eval and hasattr(base_model, 'set_xsa'): + base_model.set_xsa(False) + + model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ─── Quantization ──────────────────────────────────────────────────────────── + +# Control tensor patterns — kept at full precision during quantization +CONTROL_PATTERNS = ( + "resid_mix", "q_gain", "smear", "skip_weight", "attn_scale", "mlp_scale", +) + + +def generate_autoregressive_calib(model, device, num_seqs=64, seq_len=2048, + vocab_size=1024, temperature=0.8, batch_size=8, seed=42): + """Generate sequences autoregressively from the model for GPTQ calibration. 
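+    Starting from random single-token prompts, the model samples its own
+    continuations at `temperature`, so the calibration Hessians reflect the
+    trained model's own output distribution.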
+ No external data accessed — fully self-contained.""" + model.eval() + rng = torch.Generator(device=device) + rng.manual_seed(seed) + all_tokens = [] + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for batch_start in range(0, num_seqs, batch_size): + bs = min(batch_size, num_seqs - batch_start) + tokens = torch.randint(0, vocab_size, (bs, 1), device=device, generator=rng) + for pos in range(seq_len - 1): + logits = model.forward_logits(tokens) + next_logit = logits[:, -1, :] + probs = torch.softmax(next_logit / temperature, dim=-1) + next_tok = torch.multinomial(probs, 1, generator=rng) + tokens = torch.cat([tokens, next_tok], dim=1) + for i in range(bs): + all_tokens.append(tokens[i:i+1]) + return all_tokens + + +def collect_hessians_from_tokens(hessian_model, token_seqs, device): + """Collect H = X^T X from pre-generated token sequences.""" + hessians = {} + hooks = [] + for name, module in hessian_model.named_modules(): + if isinstance(module, CastedLinear): + param_name = name + ".weight" + cols = module.weight.shape[1] + hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') + def make_hook(pname): + def hook_fn(module, input, output): + x = input[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[pname] += (x.T @ x).cpu() + return hook_fn + h = module.register_forward_hook(make_hook(param_name)) + hooks.append(h) + hessian_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for seq in token_seqs: + x = seq[:, :-1].to(device) + y = seq[:, 1:].to(device) + hessian_model(x, y) + for h in hooks: + h.remove() + num_batches = len(token_seqs) + for name in hessians: + hessians[name] /= num_batches + return hessians + + +def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128): + """Full GPTQ: Hessian-aware int6 quantization with Cholesky error compensation. 
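+    Columns are visited in descending diag(H) order; after each column is
+    quantized, its residual error is propagated into the remaining columns
+    through the inverse-Cholesky factor (the `W1[:, i:] -= err * Hinv1[i, i:]`
+    update below).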
+ If hessian is None, falls back to percentile search.""" + t32 = weight.float() + if t32.ndim != 2 or hessian is None: + return quantize_int6_per_row(t32) + rows, cols = t32.shape + H = hessian.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * torch.mean(torch.diag(H)) + H[torch.arange(cols), torch.arange(cols)] += damp + perm = torch.argsort(torch.diag(H), descending=True) + inv_perm = torch.argsort(perm) + W = t32[:, perm].clone() + W[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.linalg.cholesky(H) + Hinv = torch.cholesky_inverse(Hinv) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + best_q = None; best_scale = None; best_err = float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + sf = s.float() + Q = torch.zeros_like(W, dtype=torch.int8) + W_work = W.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + count = i2 - i1 + W1 = W_work[:, i1:i2].clone() + Q1 = torch.zeros(rows, count, dtype=torch.int8) + Err1 = torch.zeros(rows, count) + Hinv1 = Hinv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8) + Q1[:, i] = q + err = (w - q.float() * sf) / d + W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0) + Err1[:, i] = err + Q[:, i1:i2] = Q1 + if i2 < cols: + W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + recon = Q.float() * sf[:, None] + mse = (W - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + best_q = best_q[:, inv_perm] + return best_q, best_scale + + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + """Int6 quantization with percentile search for optimal clipping.""" + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + # 1D: simple per-tensor quantization + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + + +def quantize_int8_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + """Int8 quantization with percentile clipping.""" + t32 = t.float() + clip_q = 0.9999984 + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), clip_q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0).to(torch.float16) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -127, 127).to(torch.int8) + return q, scale + clip_abs = float(torch.quantile(t32.abs().flatten(), clip_q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, 
dtype=torch.float16) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale.float()), -127, 127).to(torch.int8) + return q, scale + + +def mixed_quantize(state_dict: dict[str, Tensor], hessians: dict[str, Tensor] | None = None) -> tuple[dict[str, Tensor], dict[str, object]]: + """Mixed int6 (large weights) / int8 (small weights) quantization.""" + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Control tensors: keep fp16 passthrough + if any(p in name for p in CONTROL_PATTERNS): + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + # Non-float: passthrough + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + # Small tensors: fp16 passthrough + if t.numel() <= 65536: + result[name] = t.to(torch.float16) + meta[name] = "passthrough" + continue + # Large 2D weights: int6 (6-bit quantization for better compression) + if t.ndim == 2 and t.numel() > 65536: + H = hessians.get(name) if hessians else None + q, s = quantize_int6_gptq(t, hessian=H) if H is not None else quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + # Other float tensors: int8 + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + + +def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Dequantize mixed int6/int8 back to float.""" + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info == "passthrough": + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ─── Checkpoint Saving ─────────────────────────────────────────────────────── + +def save_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed): + base = model.module if hasattr(model, 'module') else model + ckpt = { + "step": step, "val_bpb": val_bpb, + "arch_name": arch_name, "seed": seed, + "model_state_dict": base.state_dict(), + } + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"{arch_name}_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def save_full_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + qat_enabled, rng_states=None, stream_state=None): + """Save complete training state for chained job resume.""" + base = model.module if hasattr(model, 'module') else model + ckpt = { + "step": step, "val_bpb": val_bpb, + "arch_name": arch_name, "seed": seed, + "model_state_dict": {k: v.cpu() for k, v in base.state_dict().items()}, + "muon_opt_state": muon_opt.state_dict(), + "adam_opt_state": adam_opt.state_dict(), + "ema_state": {k: v.cpu() for k, v in ema_state.items()}, + "swa_state": {k: v.cpu() for k, v in swa_state.items()} if swa_state is not None else None, + "swa_count": swa_count, + "qat_enabled": qat_enabled, + } + if rng_states is not None: + 
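+ # Saved so a chained job can resume with the same RNG draws and data position.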
ckpt["rng_states"] = rng_states + if stream_state is not None: + ckpt["stream_state"] = stream_state + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"full_ckpt_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def _find_latest_full_ckpt(ckpt_dir): + """Find the latest full_ckpt_step*.pt file in ckpt_dir by step number.""" + pattern = os.path.join(ckpt_dir, "full_ckpt_step*_seed*.pt") + files = glob.glob(pattern) + if not files: + return None + import re + step_re = re.compile(r"full_ckpt_step(\d+)_seed") + best_step, best_path = -1, None + for f in files: + m = step_re.search(os.path.basename(f)) + if m: + s = int(m.group(1)) + if s > best_step: + best_step, best_path = s, f + return best_path + + +# ─── Main Training Loop ───────────────────────────────────────────────────── + +def main(): + global zeropower_via_newtonschulz5 + args = Hyperparameters() + config = get_config(args.arch_mode) + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = max(1, 8 // world_size) + master_process = rank == 0 + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if args.compile_enabled: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # Logging + os.makedirs("logs", exist_ok=True) + os.makedirs(args.ckpt_dir, exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master_process else None + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg, flush=True) + if logfile: + with open(logfile, "a") as f: + print(msg, file=f) + + # Seeds + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + log0(f"=== GDN Hybrid 7k Full Training ===") + log0(f"Arch: {config['arch_name']} (ARCH_MODE={args.arch_mode})") + log0(f"Seed: {args.seed}, Steps: {args.iterations}, Warmdown: {args.warmdown_iters}") + log0(f"World size: {world_size}, Grad accum: {grad_accum_steps}") + log0(f"EMA decay: {args.ema_decay}, SWA: {args.swa_enabled} (every {args.swa_every})") + log0(f"Late QAT threshold: {args.late_qat_threshold}") + log0(f"Eval compile enabled: {args.eval_compile_enabled}") + + # Tokenizer + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + assert int(sp.vocab_size()) == args.vocab_size + + # Validation data + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"Validation tokens: {val_tokens.numel()-1:,}") + + # Build model + _t0 = time.time() + model = HybridGDN(config, args.vocab_size) + model = model.to(device).bfloat16() + log0(f"Model built in {time.time()-_t0:.1f}s") + + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, p in model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + + param_counts = model.count_params() + log0(f"Parameters: {param_counts}") + log0(f"Total params: 
{param_counts['total']:,}") + + # Resume from checkpoint if specified + start_step = 0 + resume_state = None # holds full checkpoint data for deferred restore + resume_ckpt_path = args.resume_ckpt + if resume_ckpt_path == "auto": + resume_ckpt_path = _find_latest_full_ckpt(args.ckpt_dir) or "" + if resume_ckpt_path: + log0(f"Auto-detected resume checkpoint: {resume_ckpt_path}") + else: + log0("Auto-resume: no full checkpoint found, starting fresh") + if resume_ckpt_path and os.path.exists(resume_ckpt_path): + log0(f"Resuming from checkpoint: {resume_ckpt_path}") + ckpt = torch.load(resume_ckpt_path, map_location="cpu", weights_only=False) + base_sd = ckpt["model_state_dict"] + model.load_state_dict({k: v.to(device) for k, v in base_sd.items()}, strict=True) + start_step = ckpt.get("step", 0) + log0(f"Resumed model at step {start_step}, val_bpb={ckpt.get('val_bpb', 'N/A')}") + # Keep full checkpoint for deferred optimizer/EMA/SWA restore + if "muon_opt_state" in ckpt: + resume_state = ckpt + log0(" Full checkpoint detected — will restore optimizers, EMA, SWA, RNG") + else: + log0(" Lightweight checkpoint — model only") + del ckpt + + # DDP + base_model = model # keep reference before wrapping + if distributed: + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + # Optimizer setup + matrix_params = [] + scalar_params = [] + embed_params = [] + for name, p in base_model.named_parameters(): + if not p.requires_grad: + continue + if "tok_emb" in name: + embed_params.append(p) + elif p.ndim == 2 and min(p.shape) >= 2: + matrix_params.append(p) + else: + scalar_params.append(p) + + log0(f"Matrix params: {sum(p.numel() for p in matrix_params):,}") + log0(f"Scalar params: {sum(p.numel() for p in scalar_params):,}") + log0(f"Embed params: {sum(p.numel() for p in embed_params):,}") + + muon_opt = Muon( + matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + adam_opt = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr}, + {"params": embed_params, "lr": args.tied_embed_lr}], + betas=(args.beta1, args.beta2), + weight_decay=args.adam_wd, + fused=True, + ) + + # Deferred restore: optimizer states (must happen after optimizer creation) + if resume_state is not None: + muon_opt.load_state_dict(resume_state["muon_opt_state"]) + adam_opt.load_state_dict(resume_state["adam_opt_state"]) + log0(" Restored optimizer states (Muon + Adam)") + + # Data loader — coprime shard ordering + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if not shard_order_file: + # Generate coprime ordering on-the-fly + shard_files = sorted(glob.glob(args.train_files)) + if shard_files: + ordered = generate_coprime_shard_order(shard_files, seed=args.seed) + shard_order_path = f"/tmp/shard_order_{args.run_id}.txt" + with open(shard_order_path, "w") as f: + for sf in ordered: + f.write(str(sf) + "\n") + os.environ["SHARD_ORDER_FILE"] = shard_order_path + log0(f"Generated coprime shard order: stride across {len(shard_files)} shards") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # LR schedule with warmdown (cosine) + def lr_schedule(step: int) -> float: + warmdown_start = args.iterations - args.warmdown_iters + if step < args.warmup_steps: + return step / max(1, args.warmup_steps) + elif step >= warmdown_start: + progress = (step - warmdown_start) / args.warmdown_iters + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return 1.0 + + # ─── EMA + SWA 
state ───────────────────────────────────────────────── + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Deferred restore: EMA, SWA, QAT, RNG, stream state + if resume_state is not None: + # EMA + saved_ema = resume_state.get("ema_state") + if saved_ema is not None: + ema_state = {k: v.to(device).float() for k, v in saved_ema.items()} + log0(" Restored EMA state") + # SWA + saved_swa = resume_state.get("swa_state") + if saved_swa is not None: + swa_state = {k: v.cpu() for k, v in saved_swa.items()} + swa_count = resume_state.get("swa_count", 0) + log0(f" Restored SWA state (count={swa_count})") + else: + swa_count = resume_state.get("swa_count", 0) + # QAT + if resume_state.get("qat_enabled", False): + CastedLinear._qat_enabled = True + log0(" Restored QAT enabled state") + # RNG states + saved_rng = resume_state.get("rng_states") + if saved_rng is not None: + torch.set_rng_state(saved_rng["torch_cpu"]) + torch.cuda.set_rng_state(saved_rng["torch_cuda"]) + np.random.set_state(saved_rng["numpy"]) + random.setstate(saved_rng["python"]) + log0(" Restored RNG states") + # Stream state (data loader fast-forward) + saved_stream = resume_state.get("stream_state") + if saved_stream is not None: + s_idx, s_pos = saved_stream + stream = train_loader.stream + # Advance to the saved shard + while stream.idx != s_idx: + stream._advance_file() + stream.pos = s_pos + log0(f" Restored stream state (shard={s_idx}, pos={s_pos})") + else: + # No stream state saved — fast-forward by consuming tokens + if start_step > 0: + log0(f" Fast-forwarding data loader by {start_step} steps...") + for _ in range(start_step): + train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + log0(f" Data loader advanced to step {start_step}") + del resume_state + log0(" Full checkpoint restore complete") + + # ─── Training Loop ─────────────────────────────────────────────────── + # Clear stale chain marker from previous segment (if any) + stale_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(stale_marker): + os.remove(stale_marker) + + log0(f"\n{'='*80}") + log0(f"Starting training: {args.iterations} steps (from step {start_step})") + log0(f"{'='*80}\n") + + t0 = time.time() + running_loss = 0.0 + loss_count = 0 + stop_after_step = None + step = start_step # ensure step is defined even if loop doesn't execute + + for step in range(start_step + 1, args.iterations + 1): + # Check early stop + if stop_after_step is not None and step > stop_after_step: + log0(f"Stopping early at step {step} (wallclock limit)") + break + + lr_mul = lr_schedule(step) + + # Muon momentum warmup + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + current_muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in muon_opt.param_groups: + group["lr"] = args.matrix_lr * lr_mul + group["momentum"] = current_muon_momentum + for i, pg in enumerate(adam_opt.param_groups): + if i == 0: + pg["lr"] = args.scalar_lr * lr_mul + else: + pg["lr"] = args.tied_embed_lr * lr_mul + + # Late QAT: activate int6 STE only during warmdown (not warmup!) 
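+ # Worked trigger arithmetic (illustrative, using this package's run shape):
+ # with iterations=7000 and warmdown_iters=2100 the warmdown starts at step 4900,
+ # and the cosine schedule gives lr_mul = 0.5*(1+cos(pi*progress)). A 0.15
+ # threshold is crossed once progress > acos(2*0.15 - 1)/pi ~= 0.75, i.e. around
+ # step 6470 on a run that reaches the full step budget.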
+ warmdown_start = args.iterations - args.warmdown_iters + if (args.late_qat_threshold > 0 and step >= warmdown_start + and lr_mul < args.late_qat_threshold and not CastedLinear._qat_enabled): + CastedLinear._qat_enabled = True + log0(f"Late QAT enabled at step {step} (lr_mul={lr_mul:.4f})") + + # Gradient accumulation + model.train() + total_loss = 0.0 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + micro_batch = x.shape[0] // grad_accum_steps + for micro_step in range(grad_accum_steps): + x_micro = x[micro_step * micro_batch:(micro_step + 1) * micro_batch] + y_micro = y[micro_step * micro_batch:(micro_step + 1) * micro_batch] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x_micro, y_micro) + loss = loss / grad_accum_steps + loss.backward() + total_loss += loss.item() + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm) + + muon_opt.step() + adam_opt.step() + muon_opt.zero_grad(set_to_none=True) + adam_opt.zero_grad(set_to_none=True) + + # EMA update (every step) + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + # SWA: collect checkpoints during late warmdown + if args.swa_enabled and lr_mul < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"SWA started at step {step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + running_loss += total_loss + loss_count += 1 + + # Logging + if step % args.train_log_every == 0 or step <= 10: + avg_loss = running_loss / max(loss_count, 1) + elapsed = time.time() - t0 + steps_per_sec = step / elapsed + log0(f"step {step:5d}/{args.iterations} | loss {avg_loss:.4f} | lr_mul {lr_mul:.4f} | " + f"mom {current_muon_momentum:.3f} | {steps_per_sec:.2f} steps/s | {elapsed:.0f}s") + running_loss = 0.0 + loss_count = 0 + + # Validation + checkpoint + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or step == args.iterations: + val_loss, val_bpb = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=args.xsa_eval, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"step {step:5d} | val_loss {val_loss:.4f} | val_bpb {val_bpb:.4f}") + + if master_process and args.save_every > 0 and (step % args.save_every == 0 or step == args.iterations): + ckpt_path = save_checkpoint( + model, step, val_bpb, args.ckpt_dir, config["arch_name"], args.seed, + ) + log0(f" Saved: {ckpt_path}") + + # Wallclock limit + if args.max_wallclock_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.max_wallclock_seconds and stop_after_step is None: + stop_after_step = step + log0(f"Wallclock limit reached ({elapsed:.0f}s), will stop after this step") + + # Auto-save for chained job support + if args.auto_save_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.auto_save_seconds: + log0(f"Auto-save triggered at step {step} ({elapsed:.0f}s elapsed)") + if master_process: + rng_states = { + "torch_cpu": torch.get_rng_state(), + "torch_cuda": torch.cuda.get_rng_state(), + "numpy": np.random.get_state(), + "python": random.getstate(), + } + stream = train_loader.stream + 
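+ # Record (shard index, offset) so resume can fast-forward the loader directly instead of replaying batches.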
stream_state = (stream.idx, stream.pos) + ckpt_path = save_full_checkpoint( + model, step, 0.0, args.ckpt_dir, config["arch_name"], args.seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + CastedLinear._qat_enabled, + rng_states=rng_states, stream_state=stream_state, + ) + # Write chain resume marker + marker_path = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + with open(marker_path, "w") as f: + f.write(ckpt_path + "\n") + log0(f" Full checkpoint saved: {ckpt_path}") + log0(f" Chain resume marker: {marker_path}") + break # exit training loop cleanly + + # ─── Check if we exited due to auto-save vs normal completion ──────── + chain_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(chain_marker): + log0("\nExiting for chained job resume (skipping post-training)") + if distributed: + dist.destroy_process_group() + return + + # Check for total_iterations completion + effective_total = args.total_iterations if args.total_iterations > 0 else args.iterations + if master_process and step >= effective_total: + complete_marker = os.path.join(args.ckpt_dir, f"TRAINING_COMPLETE_seed{args.seed}") + with open(complete_marker, "w") as f: + f.write(f"step={step}\n") + + # ─── Post-Training: Apply EMA ──────────────────────────────────────── + elapsed_total = time.time() - t0 + log0(f"\nTraining complete in {elapsed_total:.0f}s") + log0(f"Peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + log0("\n=== Applying EMA weights ===") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + # Eval EMA weights + val_loss_ema, val_bpb_ema = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"EMA BPB (no XSA): {val_bpb_ema:.6f}") + + # Save raw EMA model + if master_process: + torch.save(base_model.state_dict(), os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.pt")) + log0("Saved raw EMA model") + + # ─── GPTQ Calibration (optional) ───────────────────────────────────── + gptq_enabled = bool(int(os.environ.get("GPTQ_ENABLED", "0"))) + hessians = None + if gptq_enabled: + log0("\n=== GPTQ: generating autoregressive calibration data ===") + calib_seqs = generate_autoregressive_calib( + base_model, device, num_seqs=64, seq_len=args.train_seq_len, + vocab_size=args.vocab_size, temperature=0.8, batch_size=8, seed=args.seed, + ) + log0(f"GPTQ: generated {len(calib_seqs)} sequences, collecting hessians...") + hessians = collect_hessians_from_tokens(base_model, calib_seqs, device) + log0(f"GPTQ: collected hessians for {len(hessians)} layers") + + # ─── Quantization + Artifact Creation ──────────────────────────────── + log0("\n=== Quantizing to int6 + zstd-22 ===") + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, hessians=hessians) + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + + artifact_path = os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.int6.ptz") + if master_process: + with open(artifact_path, "wb") as f: + 
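+ # Write the zstd-22 compressed torch.save payload; its byte length is the scored artifact size.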
f.write(quant_blob) + artifact_bytes = len(quant_blob) + log0(f"Artifact: {artifact_bytes:,} bytes ({artifact_bytes / 1024 / 1024:.2f} MB)") + if artifact_bytes > 16 * 1024 * 1024: + log0(f"WARNING: Artifact exceeds 16MB budget by {(artifact_bytes - 16*1024*1024) / 1024:.1f} KB") + + # ─── Roundtrip Validation ──────────────────────────────────────────── + log0("\n=== Roundtrip Validation (quantized model) ===") + if distributed: + dist.barrier() + + with open(artifact_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed(quant_state["w"], quant_state["m"], sd_cpu) + + # Build fresh eval model (no DDP wrapping needed) + eval_model = HybridGDN(config, args.vocab_size).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + for name, p in eval_model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + eval_model.load_state_dict(deq_state, strict=True) + + # Eval quantized model without XSA + val_loss_q, val_bpb_q = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"Quantized BPB (no XSA): {val_bpb_q:.6f}") + log0(f"Quantization degradation: {val_bpb_q - val_bpb_ema:+.6f}") + + # Eval quantized model WITH XSA (if model has SWA layers) + block_types = eval_model._block_types + if any(bt in ("swa", "swa_shared") for bt in block_types): + val_loss_qx, val_bpb_qx = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=True, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"Quantized BPB (XSA-all): {val_bpb_qx:.6f}") + + # ─── Final Summary ─────────────────────────────────────────────────── + log0(f"\n{'='*80}") + log0(f"FINAL RESULTS — {config['arch_name']} seed={args.seed}") + log0(f" Training: {args.iterations} steps, {elapsed_total:.0f}s") + log0(f" EMA BPB (fp32): {val_bpb_ema:.6f}") + log0(f" Quantized BPB: {val_bpb_q:.6f}") + if any(bt in ("swa", "swa_shared") for bt in block_types): + log0(f" Quantized BPB+XSA: {val_bpb_qx:.6f}") + if master_process: + log0(f" Artifact size: {artifact_bytes:,} bytes") + log0(f"{'='*80}") + log0(f"final_int6_roundtrip_exact val_loss:{val_loss_q:.8f} val_bpb:{val_bpb_q:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_gpt.py b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_gpt.py new file mode 100644 index 0000000000..c92492b069 --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_gpt.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""FLA / GatedDeltaNet entrypoint wrapper. + +The actual training logic lives in `train_gdn_7k.py`. `evaluate.py` expects +`torchrun train_gpt.py`, so this wrapper preserves the standard repo entrypoint +while keeping the scored path in the records folder self-contained. +""" + +import os +import sys +import traceback +from pathlib import Path + +# These defaults keep the wrapper aligned with the intended SP8192 scored path. 
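+# Note: setdefault only fills unset variables, so values exported by the launcher take precedence.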
+VOCAB_SIZE = int(os.environ.get("VOCAB_SIZE", 8192)) +DATA_PATH = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192") +TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") +ARCH_MODE = os.environ.get("ARCH_MODE", "K") +os.environ.setdefault("VOCAB_SIZE", str(VOCAB_SIZE)) +os.environ.setdefault("DATA_PATH", DATA_PATH) +os.environ.setdefault("TOKENIZER_PATH", TOKENIZER_PATH) +os.environ.setdefault("ARCH_MODE", ARCH_MODE) +os.environ.setdefault("MAX_WALLCLOCK_SECONDS", "600") +os.environ.setdefault("VAL_LOSS_EVERY", "0") +os.environ.setdefault("EVAL_COMPILE_ENABLED", "0") +if ARCH_MODE in ("D", "G", "M"): + os.environ.setdefault("XSA_EVAL", "1") + + +_VENDOR_DIR = Path(__file__).resolve().parent / ".fla_vendor" +_VENDOR_PKGS = [ + "triton==3.2.0", + "flash-linear-attention==0.4.2", + "fla-core==0.4.2", + "transformers==5.5.4", + "tokenizers==0.22.2", + "safetensors==0.7.0", +] +if ARCH_MODE in ("F", "G"): + _VENDOR_PKGS.extend( + [ + "mamba-ssm==2.3.1", + "causal-conv1d==1.6.1", + ] + ) + + +def _ensure_vendor_on_path() -> None: + p = str(_VENDOR_DIR) + if p not in sys.path: + sys.path.insert(0, p) + + +def _ensure_fla_vendor_available() -> None: + _ensure_vendor_on_path() + try: + if ARCH_MODE in ("F", "G"): + from fla.layers.mamba2 import Mamba2 # noqa: F401 + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined # noqa: F401 + from causal_conv1d import causal_conv1d_fn # noqa: F401 + else: + from fla.layers.gated_deltanet import GatedDeltaNet # noqa: F401 + print("wrapper: local vendored FLA imports already work", flush=True) + return + except Exception: + vendor_pkgs = ", ".join(_VENDOR_PKGS) + raise RuntimeError( + "wrapper: required FLA deps are missing from the local environment. " + f"Expected vendored packages under {_VENDOR_DIR}. " + f"Install them before evaluation (e.g. via launcher/requirements), packages: {vendor_pkgs}" + ) + + +def main(): + _ensure_fla_vendor_available() + print("wrapper: importing train_gdn_7k", flush=True) + try: + from train_gdn_7k import main as train_main + except Exception: + traceback.print_exc() + raise + print("wrapper: import ok, entering train_main", flush=True) + train_main() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed1337.log b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed1337.log new file mode 100644 index 0000000000..da3eb634d7 --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed1337.log @@ -0,0 +1,77 @@ +data_setup: vocab=8192 shards=128 +Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads. +W0416 17:03:08.789000 305 torch/distributed/run.py:803] +W0416 17:03:08.789000 305 torch/distributed/run.py:803] ***************************************** +W0416 17:03:08.789000 305 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+wrapper: local vendored FLA imports already work +wrapper: importing train_gdn_7k +wrapper: import ok, entering train_main +=== GDN Hybrid 7k Full Training === +Arch: K_KVShare_Wider (ARCH_MODE=K) +Seed: 1337, Steps: 7000, Warmdown: 2100 +World size: 8, Grad accum: 1 +EMA decay: 0.997, SWA: True (every 50) +Late QAT threshold: 0.15 +Eval compile enabled: False +Validation tokens: 40,540,160 +Model built in 0.3s +Parameters: {'embedding': 4861441, 'recurrent': 11269920, 'attention': 5440, 'mlp': 17761600, 'other': 11424, 'total': 33909825} +Total params: 33,909,825 +Matrix params: 29,400,192 +Scalar params: 53,185 +Embed params: 4,456,448 +Generated coprime shard order: stride across 128 shards +================================================================================ +Starting training: 7000 steps (from step 0) +[rank0]:[W416 17:04:22.464538238 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W416 17:04:22.554721511 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W416 17:04:22.578729687 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank6]:[W416 17:04:22.592133300 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W416 17:04:22.598708539 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. 
Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W416 17:04:22.610173671 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W416 17:04:22.642753158 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W416 17:04:22.654023506 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +step 1/7000 | loss 9.0073 | lr_mul 0.0500 | mom 0.850 | 0.01 steps/s | 126s +step 2/7000 | loss 8.7870 | lr_mul 0.1000 | mom 0.850 | 0.02 steps/s | 126s +step 3/7000 | loss 8.2119 | lr_mul 0.1500 | mom 0.851 | 0.02 steps/s | 126s +step 4/7000 | loss 7.6929 | lr_mul 0.2000 | mom 0.851 | 0.03 steps/s | 127s +step 5/7000 | loss 7.3819 | lr_mul 0.2500 | mom 0.851 | 0.04 steps/s | 127s +step 6/7000 | loss 7.3836 | lr_mul 0.3000 | mom 0.851 | 0.05 steps/s | 127s +step 7/7000 | loss 7.3728 | lr_mul 0.3500 | mom 0.851 | 0.05 steps/s | 128s +step 8/7000 | loss 7.2195 | lr_mul 0.4000 | mom 0.852 | 0.06 steps/s | 128s +step 9/7000 | loss 7.0525 | lr_mul 0.4500 | mom 0.852 | 0.07 steps/s | 128s +step 10/7000 | loss 6.9312 | lr_mul 0.5000 | mom 0.852 | 0.08 steps/s | 128s +step 100/7000 | loss 5.0926 | lr_mul 1.0000 | mom 0.870 | 0.65 steps/s | 154s +step 200/7000 | loss 4.1688 | lr_mul 1.0000 | mom 0.890 | 1.09 steps/s | 183s +step 300/7000 | loss 3.7872 | lr_mul 1.0000 | mom 0.910 | 1.42 steps/s | 212s +step 400/7000 | loss 3.6149 | lr_mul 1.0000 | mom 0.930 | 1.66 steps/s | 240s +step 500/7000 | loss 3.5175 | lr_mul 1.0000 | mom 0.950 | 1.86 steps/s | 269s +step 600/7000 | loss 3.4531 | lr_mul 1.0000 | mom 0.950 | 2.01 steps/s | 298s +step 700/7000 | loss 3.4341 | lr_mul 1.0000 | mom 0.950 | 2.14 steps/s | 327s +step 800/7000 | loss 3.3717 | lr_mul 1.0000 | mom 0.950 | 2.25 steps/s | 355s +step 900/7000 | loss 3.3735 | lr_mul 1.0000 | mom 0.950 | 2.34 steps/s | 384s +step 1000/7000 | loss 3.3300 | lr_mul 1.0000 | mom 0.950 | 2.42 steps/s | 413s +step 1100/7000 | loss 3.3436 | lr_mul 1.0000 | mom 0.950 | 2.49 steps/s | 442s +step 1200/7000 | loss 3.2732 | lr_mul 1.0000 | mom 0.950 | 2.55 steps/s | 470s +step 1300/7000 | loss 3.2799 | lr_mul 1.0000 | mom 0.950 | 2.60 steps/s | 499s +step 1400/7000 | loss 3.2655 | lr_mul 1.0000 | mom 0.950 | 2.65 steps/s | 528s +step 1500/7000 | loss 3.2346 | lr_mul 1.0000 | mom 0.950 | 2.69 steps/s | 557s +step 1600/7000 | loss 3.2463 | lr_mul 1.0000 | mom 0.950 | 2.73 steps/s | 586s +Wallclock limit reached (600s), will stop after this step +Stopping early at step 1652 (wallclock limit) +Training complete in 600s +Peak memory: 41127 MiB +=== Applying EMA weights === +EMA BPB (no XSA): 1.020660 +Saved raw EMA model +=== Quantizing to int6 + zstd-22 === +Artifact: 15,762,406 bytes (15.03 MB) +=== Roundtrip Validation (quantized model) === +Quantized BPB (no XSA): 1.039674 +Quantization degradation: +0.019015 +FINAL RESULTS — K_KVShare_Wider seed=1337 + Training: 7000 steps, 600s + EMA BPB (fp32): 1.020660 + Quantized BPB: 1.039674 + Artifact size: 15,762,406 bytes +final_int6_roundtrip_exact val_loss:3.16104735 val_bpb:1.03967403 diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed2025.log b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed2025.log new file mode 100644 index 0000000000..4b95196353 --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed2025.log @@ -0,0 +1,76 @@ +data_setup: vocab=8192 shards=128 +Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads. 
+W0416 17:46:47.160000 305 torch/distributed/run.py:803] +W0416 17:46:47.160000 305 torch/distributed/run.py:803] ***************************************** +W0416 17:46:47.160000 305 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +wrapper: local vendored FLA imports already work +wrapper: importing train_gdn_7k +wrapper: import ok, entering train_main +=== GDN Hybrid 7k Full Training === +Arch: K_KVShare_Wider (ARCH_MODE=K) +Seed: 2025, Steps: 7000, Warmdown: 2100 +World size: 8, Grad accum: 1 +EMA decay: 0.997, SWA: True (every 50) +Late QAT threshold: 0.15 +Eval compile enabled: False +Validation tokens: 40,540,160 +Model built in 0.2s +Parameters: {'embedding': 4861441, 'recurrent': 11269920, 'attention': 5440, 'mlp': 17761600, 'other': 11424, 'total': 33909825} +Total params: 33,909,825 +Matrix params: 29,400,192 +Scalar params: 53,185 +Embed params: 4,456,448 +Generated coprime shard order: stride across 128 shards +================================================================================ +Starting training: 7000 steps (from step 0) +[rank1]:[W416 17:48:20.663814217 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W416 17:48:20.668355002 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank6]:[W416 17:48:20.672315856 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W416 17:48:20.675186367 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +[rank0]:[W416 17:48:20.675348220 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W416 17:48:20.679108359 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W416 17:48:20.681119312 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W416 17:48:20.705809763 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +step 1/7000 | loss 8.9997 | lr_mul 0.0500 | mom 0.850 | 0.01 steps/s | 145s +step 2/7000 | loss 8.7711 | lr_mul 0.1000 | mom 0.850 | 0.01 steps/s | 146s +step 3/7000 | loss 8.1605 | lr_mul 0.1500 | mom 0.851 | 0.02 steps/s | 146s +step 4/7000 | loss 7.6251 | lr_mul 0.2000 | mom 0.851 | 0.03 steps/s | 146s +step 5/7000 | loss 7.3888 | lr_mul 0.2500 | mom 0.851 | 0.03 steps/s | 147s +step 6/7000 | loss 7.3862 | lr_mul 0.3000 | mom 0.851 | 0.04 steps/s | 147s +step 7/7000 | loss 7.3875 | lr_mul 0.3500 | mom 0.851 | 0.05 steps/s | 147s +step 8/7000 | loss 7.1988 | lr_mul 0.4000 | mom 0.852 | 0.05 steps/s | 148s +step 9/7000 | loss 7.0120 | lr_mul 0.4500 | mom 0.852 | 0.06 steps/s | 148s +step 10/7000 | loss 6.8351 | lr_mul 0.5000 | mom 0.852 | 0.07 steps/s | 148s +step 100/7000 | loss 5.1111 | lr_mul 1.0000 | mom 0.870 | 0.57 steps/s | 174s +step 200/7000 | loss 4.1321 | lr_mul 1.0000 | mom 0.890 | 0.99 steps/s | 203s +step 300/7000 | loss 3.7850 | lr_mul 1.0000 | mom 0.910 | 1.30 steps/s | 231s +step 400/7000 | loss 3.6333 | lr_mul 1.0000 | mom 0.930 | 1.54 steps/s | 260s +step 500/7000 | loss 3.5216 | lr_mul 1.0000 | mom 0.950 | 1.73 steps/s | 289s +step 600/7000 | loss 3.4740 | lr_mul 1.0000 | mom 0.950 | 1.89 steps/s | 318s +step 700/7000 | loss 3.4165 | lr_mul 1.0000 | mom 0.950 | 2.02 steps/s | 347s +step 800/7000 | loss 3.3769 | lr_mul 1.0000 | mom 0.950 | 2.13 steps/s | 375s +step 900/7000 | loss 3.3826 | lr_mul 1.0000 | mom 0.950 | 2.23 steps/s | 404s +step 1000/7000 | loss 3.3284 | lr_mul 1.0000 | mom 0.950 | 2.31 steps/s | 433s +step 1100/7000 | loss 3.3195 | lr_mul 1.0000 | mom 0.950 | 2.38 steps/s | 462s +step 1200/7000 | loss 3.3091 | lr_mul 1.0000 | mom 0.950 | 2.45 steps/s | 490s +step 1300/7000 | loss 3.2936 | lr_mul 1.0000 | mom 0.950 | 2.50 steps/s | 519s +step 1400/7000 | loss 3.2604 | lr_mul 1.0000 | mom 0.950 | 2.56 steps/s | 548s +step 1500/7000 | loss 3.2757 | lr_mul 1.0000 | mom 0.950 | 2.60 steps/s | 576s +Wallclock limit reached (600s), will stop after this step +Stopping early at step 1583 (wallclock limit) +Training complete in 600s +Peak memory: 41127 MiB +=== Applying EMA weights === +EMA BPB (no XSA): 1.023994 +Saved raw EMA model +=== Quantizing to int6 + zstd-22 === +Artifact: 15,648,800 bytes (14.92 MB) +=== Roundtrip Validation (quantized model) === +Quantized BPB (no XSA): 1.041482 +Quantization degradation: +0.017488 +FINAL RESULTS — K_KVShare_Wider seed=2025 + Training: 7000 steps, 600s + EMA BPB (fp32): 1.023994 + Quantized BPB: 1.041482 + Artifact size: 15,648,800 bytes +final_int6_roundtrip_exact val_loss:3.16654364 val_bpb:1.04148177 diff --git a/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed42.log b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed42.log new file mode 100644 index 0000000000..9af658c69f --- /dev/null +++ b/records/track_10min_16mb/2026-04-16_KKVShareWider_FLA/train_seed42.log @@ -0,0 +1,77 @@ +data_setup: vocab=8192 shards=128 +Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads. +W0416 17:20:40.745000 305 torch/distributed/run.py:803] +W0416 17:20:40.745000 305 torch/distributed/run.py:803] ***************************************** +W0416 17:20:40.745000 305 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+wrapper: local vendored FLA imports already work +wrapper: importing train_gdn_7k +wrapper: import ok, entering train_main +=== GDN Hybrid 7k Full Training === +Arch: K_KVShare_Wider (ARCH_MODE=K) +Seed: 42, Steps: 7000, Warmdown: 2100 +World size: 8, Grad accum: 1 +EMA decay: 0.997, SWA: True (every 50) +Late QAT threshold: 0.15 +Eval compile enabled: False +Validation tokens: 40,540,160 +Model built in 0.3s +Parameters: {'embedding': 4861441, 'recurrent': 11269920, 'attention': 5440, 'mlp': 17761600, 'other': 11424, 'total': 33909825} +Total params: 33,909,825 +Matrix params: 29,400,192 +Scalar params: 53,185 +Embed params: 4,456,448 +Generated coprime shard order: stride across 128 shards +================================================================================ +Starting training: 7000 steps (from step 0) +[rank0]:[W416 17:21:55.628946976 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W416 17:21:55.685417161 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W416 17:21:55.765165093 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W416 17:21:55.796641325 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W416 17:21:55.807222331 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. 
Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank6]:[W416 17:21:55.844149809 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W416 17:21:55.845894386 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W416 17:21:55.847535711 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +step 1/7000 | loss 8.9998 | lr_mul 0.0500 | mom 0.850 | 0.01 steps/s | 125s +step 2/7000 | loss 8.7867 | lr_mul 0.1000 | mom 0.850 | 0.02 steps/s | 126s +step 3/7000 | loss 8.0541 | lr_mul 0.1500 | mom 0.851 | 0.02 steps/s | 126s +step 4/7000 | loss 7.6256 | lr_mul 0.2000 | mom 0.851 | 0.03 steps/s | 127s +step 5/7000 | loss 7.3438 | lr_mul 0.2500 | mom 0.851 | 0.04 steps/s | 127s +step 6/7000 | loss 7.3183 | lr_mul 0.3000 | mom 0.851 | 0.05 steps/s | 127s +step 7/7000 | loss 7.2767 | lr_mul 0.3500 | mom 0.851 | 0.05 steps/s | 128s +step 8/7000 | loss 7.2274 | lr_mul 0.4000 | mom 0.852 | 0.06 steps/s | 128s +step 9/7000 | loss 6.9283 | lr_mul 0.4500 | mom 0.852 | 0.07 steps/s | 128s +step 10/7000 | loss 6.8781 | lr_mul 0.5000 | mom 0.852 | 0.08 steps/s | 128s +step 100/7000 | loss 5.1039 | lr_mul 1.0000 | mom 0.870 | 0.65 steps/s | 154s +step 200/7000 | loss 4.1381 | lr_mul 1.0000 | mom 0.890 | 1.09 steps/s | 183s +step 300/7000 | loss 3.7795 | lr_mul 1.0000 | mom 0.910 | 1.41 steps/s | 212s +step 400/7000 | loss 3.6277 | lr_mul 1.0000 | mom 0.930 | 1.66 steps/s | 241s +step 500/7000 | loss 3.5450 | lr_mul 1.0000 | mom 0.950 | 1.86 steps/s | 269s +step 600/7000 | loss 3.4821 | lr_mul 1.0000 | mom 0.950 | 2.01 steps/s | 298s +step 700/7000 | loss 3.4313 | lr_mul 1.0000 | mom 0.950 | 2.14 steps/s | 327s +step 800/7000 | loss 3.3928 | lr_mul 1.0000 | mom 0.950 | 2.25 steps/s | 356s +step 900/7000 | loss 3.3556 | lr_mul 1.0000 | mom 0.950 | 2.34 steps/s | 384s +step 1000/7000 | loss 3.3242 | lr_mul 1.0000 | mom 0.950 | 2.42 steps/s | 413s +step 1100/7000 | loss 3.3304 | lr_mul 1.0000 | mom 0.950 | 2.49 steps/s | 442s +step 1200/7000 | loss 3.3045 | lr_mul 1.0000 | mom 0.950 | 2.55 steps/s | 471s +step 1300/7000 | loss 3.2894 | lr_mul 1.0000 | mom 0.950 | 2.60 steps/s | 500s +step 1400/7000 | loss 3.2761 | lr_mul 1.0000 | mom 0.950 | 2.65 steps/s | 528s +step 1500/7000 | loss 3.2672 | lr_mul 1.0000 | mom 0.950 | 2.69 steps/s | 557s +step 1600/7000 | loss 3.2524 | lr_mul 1.0000 | mom 0.950 | 2.73 steps/s | 586s +Wallclock limit reached (600s), will stop after this step +Stopping early at step 1652 (wallclock limit) +Training complete in 600s +Peak memory: 41127 MiB +=== Applying EMA weights === +EMA BPB (no XSA): 1.022042 +Saved raw EMA model +=== Quantizing to int6 + zstd-22 === +Artifact: 15,870,797 bytes (15.14 MB) +=== Roundtrip Validation (quantized model) === +Quantized BPB (no XSA): 1.041537 +Quantization degradation: +0.019495 +FINAL RESULTS — K_KVShare_Wider seed=42 + Training: 7000 steps, 600s + EMA BPB (fp32): 1.022042 + Quantized BPB: 1.041537 + Artifact size: 15,870,797 bytes +final_int6_roundtrip_exact val_loss:3.16671180 val_bpb:1.04153708 From be976e32aa704c313f7ad3618d5af15e6b25fccb Mon Sep 17 00:00:00 2001 From: genji0306 Date: Sat, 18 Apr 2026 05:58:27 +0700 Subject: [PATCH 2/3] =?UTF-8?q?Record:=20K=5FKVShare=5FWider=20FLA=20?= =?UTF-8?q?=E2=80=94=20val=5Fbpb=201.0339=20(3-seed=20mean)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Independent 3-seed reproduction of GatedDeltaNet K_KVShare_Wider on 8xH100 SXM. Builds on PR #1687 (resouer). No TTT, no SLOT, no n-gram. 
Seeds: 42 (1.0353), 1337 (1.0333), 2025 (1.0330) Mean: 1.0339 ± 0.0012 | Artifact: 15.88 MB mean --- .../README.md | 50 + .../architectures.py | 709 ++++++++++ .../configs.py | 316 +++++ .../requirements.txt | 10 + .../submission.json | 38 + .../train_gdn_7k.py | 1172 +++++++++++++++++ .../train_gpt.py | 87 ++ 7 files changed, 2382 insertions(+) create mode 100644 records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/README.md create mode 100644 records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/architectures.py create mode 100644 records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/configs.py create mode 100644 records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/requirements.txt create mode 100644 records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/submission.json create mode 100644 records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gdn_7k.py create mode 100644 records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gpt.py diff --git a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/README.md b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/README.md new file mode 100644 index 0000000000..5f0307154c --- /dev/null +++ b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/README.md @@ -0,0 +1,50 @@ +# Record: K_KVShare_Wider FLA (Opensens reproduction) + +**val_bpb: 1.0339** (3-seed mean, std 0.0012) | **3.1434 nats** | **8xH100 SXM, 600s** | **No TTT** + +Independent 3-seed reproduction of GatedDeltaNet K_KVShare_Wider, building on +PR #1687 (@resouer). Improved results (1.0339 vs 1.0409) likely due to hardware +variance (RunPod secure cloud, IN region). + +## Results + +| Seed | Steps | EMA BPB | **Quantized BPB** | Artifact | +|------|------:|--------:|------------------:|---------:| +| 42 | 1881 | 1.016763 | **1.03527246** | 15,927,295 | +| 1337 | 1890 | 1.013801 | **1.03326043** | 15,830,641 | +| 2025 | 1884 | 1.014923 | **1.03303636** | 15,893,661 | +| **Mean** | **1885** | **1.015162** | **1.03385760** | **15,883,866** | + +## Technique + +- GatedDeltaNet / Flash Linear Attention (`K_KVShare_Wider` config) +- 10 GDN layers, model_dim=544, 8 heads, head_dim=64 +- KV sharing stride=2 (5 unique K/V sets for 10 layers) +- MLP mult=3.0, ReLU-squared, logit softcap=30 +- BigramHash(3072, 112) + trigram embeddings +- SP8192 tokenizer (from kevclark/parameter-golf HF dataset) +- Muon optimizer (momentum 0.95, WD 0.04) +- EMA decay=0.997 + SWA every 50 steps +- Late QAT (Int6 STE when LR < 15% of peak) +- Int6 + zstd-22 artifact compression + +Not used: no TTT, no SLOT, no n-gram overlay, no XSA eval. + +## Reproducibility + +```bash +pip install -r requirements.txt + +# Download SP8192 data +MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192 +# Or use snapshot_download from huggingface_hub + +SEED=$SEED ARCH_MODE=K MAX_WALLCLOCK_SECONDS=600 VAL_LOSS_EVERY=0 EVAL_COMPILE_ENABLED=0 \ + torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Attribution + +This submission reproduces and validates the architecture from PR #1687 by @resouer. +The GatedDeltaNet architecture is from Yang, Kautz & Hatamizadeh (NVIDIA, ICLR 2025). +Flash Linear Attention library by @sustcsonglin and @yzhangcs. 
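+
+## Artifact Size Check
+
+A minimal sketch for auditing the byte budget after a run. It assumes the
+artifact naming from `train_gdn_7k.py` and a hypothetical `ckpts/` output
+directory; adjust the path to your configured `ckpt_dir`.
+
+```bash
+# Each seed's quantized artifact must stay under the 16,000,000-byte budget.
+for f in ckpts/final_model_K_KVShare_Wider_seed*.int6.ptz; do
+  wc -c "$f"
+done
+```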
diff --git a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/architectures.py b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/architectures.py new file mode 100644 index 0000000000..dff180156c --- /dev/null +++ b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/architectures.py @@ -0,0 +1,709 @@ +"""GDN Hybrid Architecture — modular blocks using FLA native layers. + +Supports 8 model variants (A-H) for the Parameter Golf screening experiments. +Each model is a stack of mixed {GDN, DeltaProduct, RWKV-7, Mamba-2, SWA} blocks +with shared MLP, RMSNorm, and residual connections. + +Key design choices: +- FLA layers handle recurrent attention (GatedDeltaNet, GatedDeltaProduct, RWKV7, Mamba2) +- Sliding Window Attention (SWA) uses flash attention with a causal window mask +- All blocks follow the same pre-norm residual pattern for uniform gradient flow +- Weight sharing for SWA layers in Zamba/Hymba-style models +- Score-first eval: XSA-all only extends attention layers (no future context leakage) +""" +from __future__ import annotations +import math +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +# ─── FLA backend selection ────────────────────────────────────────────────── +# Set FLA_USE_NAIVE=1 to force pure-PyTorch (naive) kernels instead of Triton. +# This is needed when: +# - Running on V100 (sm_70) which doesn't support FLA's Triton kernels well +# - Triton cache is corrupted (FileNotFoundError on .json files) +# - Debugging without Triton dependency +# +# On A100 (sm_80+), the Triton kernels are ~3-10x faster and should be used. +_USE_NAIVE = os.environ.get("FLA_USE_NAIVE", "0") == "1" + +if _USE_NAIVE: + # 1. Patch GatedDeltaNet's chunk op + import fla.ops.gated_delta_rule.chunk as _gdr_chunk + import fla.ops.gated_delta_rule.naive as _gdr_naive + + def _patched_chunk_gated_delta_rule( + q, k, v, g, beta, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdr_naive.naive_chunk_gated_delta_rule( + q, k, v, g, beta, + chunk_size=64, scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + ) + + _gdr_chunk.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + import fla.layers.gated_deltanet as _gdn_layer + _gdn_layer.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + + # 2. 
Patch GatedDeltaProduct's chunk op + import fla.ops.gated_delta_product.chunk as _gdp_chunk + import fla.ops.gated_delta_product.naive as _gdp_naive + + def _patched_chunk_gated_delta_product( + q, k, v, g, beta, num_householder=1, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdp_naive.naive_recurrent_gated_delta_product( + q, k, v, g, beta, + scale=scale, cu_seqlens=None, + initial_state=initial_state, + output_final_state=output_final_state, + num_householder=num_householder, + ) + + _gdp_chunk.chunk_gated_delta_product = _patched_chunk_gated_delta_product + import fla.layers.gated_deltaproduct as _gdp_layer + _gdp_layer.chunk_gated_delta_product = _patched_chunk_gated_delta_product + + print("[FLA] Using NAIVE (pure-PyTorch) kernels — set FLA_USE_NAIVE=0 for Triton", flush=True) + +# FLA imports +from fla.layers import GatedDeltaNet, GatedDeltaProduct, Mamba2 +try: + from fla.layers import RWKV7Attention +except Exception: + RWKV7Attention = None # type: ignore + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except ImportError: + def flash_attn_3_func(q, k, v, causal=False, window_size=(-1, -1)): + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + if k2.size(1) != q2.size(1): + rep = q2.size(1) // k2.size(1) + k2 = k2.repeat_interleave(rep, dim=1) + v2 = v2.repeat_interleave(rep, dim=1) + out = torch.nn.functional.scaled_dot_product_attention(q2, k2, v2, is_causal=causal) + return out.transpose(1, 2) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int | None = None, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + """Linear layer that casts input to weight dtype for mixed precision. 
+    Supports late QAT (int6 STE) when _qat_enabled is set."""
+    _qat_enabled: bool = False
+
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight.to(dtype=x.dtype)
+        if CastedLinear._qat_enabled and self.training and w.ndim == 2:
+            with torch.no_grad():
+                w32 = self.weight.float()
+                row_max = w32.abs().amax(dim=1)
+                scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
+                w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) * scale[:, None]).to(x.dtype)
+            w = w + (w_q - w).detach()  # STE: forward uses quantized, backward uses full
+        bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None
+        return F.linear(x, w, bias)
+
+
+class Rotary(nn.Module):
+    """RoPE embeddings for sliding window attention."""
+    def __init__(self, dim: int, base: float = 10000.0, max_len: int = 4096):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.max_len = max_len
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype):
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+        freqs = torch.outer(t, self.inv_freq.to(device))
+        cos = freqs.cos().to(dtype)
+        sin = freqs.sin().to(dtype)
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    """Apply RoPE to a (B, T, H, D) tensor; cos/sin are (T, D/2) position tables."""
+    d = x.shape[-1] // 2
+    x1, x2 = x[..., :d], x[..., d:]
+    # Index the tables by sequence position (dim 1 of x) and broadcast over batch
+    # and heads; slicing by x.shape[-2] would select angles by head index instead.
+    cos_t = cos[: x.shape[1]][None, :, None, :]
+    sin_t = sin[: x.shape[1]][None, :, None, :]
+    out1 = x1 * cos_t - x2 * sin_t
+    out2 = x2 * cos_t + x1 * sin_t
+    return torch.cat([out1, out2], dim=-1)
+
+
+class MLP(nn.Module):
+    """Feed-forward MLP with configurable activation."""
+    def __init__(self, dim: int, mult: float = 3.0, act: str = "relu_sq", leaky_slope: float = 0.5):
+        super().__init__()
+        hidden = int(mult * dim)
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        nn.init.zeros_(self.proj.weight)  # zero-init output for residual
+        self.act = act
+        self.leaky_slope = leaky_slope
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc(x)
+        if self.act == "leaky_relu_sq":
+            x = F.leaky_relu(x, negative_slope=self.leaky_slope)
+        else:
+            x = F.relu(x)
+        return self.proj(x.square())
+
+
+class SlidingWindowAttention(nn.Module):
+    """Sliding window causal attention for hybrid models.
+
+    Supports XSA (cross-segment attention) at eval time for extending context
+    across eval chunks. Window is enforced during training but can be relaxed at eval.
+    KV can be shared across layers (Zamba-style) by reusing the same module.
+ """ + def __init__( + self, + dim: int, + num_heads: int = 8, + num_kv_heads: int = 4, + window_size: int = 512, + rope_base: float = 10000.0, + qk_gain_init: float = 1.5, + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.window_size = window_size + + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + nn.init.zeros_(self.proj.weight) + + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.use_xsa = False # enabled at eval time for XSA-all + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """XSA: subtract self-value projection (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + B, T, D = x.shape + q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(B, T, self.num_kv_heads, self.head_dim) + + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(T, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + + if q.is_cuda and q.dtype not in (torch.float16, torch.bfloat16): + q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16) + + # Use window during training, full causal at eval if XSA enabled + y = flash_attn_3_func(q, k, v, causal=True) + + if self.use_xsa: + y = self._xsa_efficient(y, v) + + y = y.reshape(B, T, D) + return self.proj(y) + + +class RecurrentBlock(nn.Module): + """Wraps any FLA recurrent layer (GDN, DeltaProduct, RWKV-7, Mamba-2) with + pre-norm residual connection and MLP.""" + + def __init__( + self, + dim: int, + recurrent_layer: nn.Module, + mlp_mult: float = 3.0, + mlp_act: str = "relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.recurrent = recurrent_layer + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # FLA layers return (output, state) or just output depending on mode + recurrent_out = self.recurrent(self.attn_norm(x_in)) + if isinstance(recurrent_out, tuple): + recurrent_out = recurrent_out[0] + + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * recurrent_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out)) + return x_out + + +class AttentionBlock(nn.Module): + """SWA block with pre-norm residual and MLP.""" + + def __init__( + self, + dim: int, + 
swa: SlidingWindowAttention, + mlp_mult: float = 3.0, + mlp_act: str = "relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.attn = swa + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in), v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out)) + return x_out + + +class SmearGate(nn.Module): + """Weighted average of current and previous token embeddings.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash-based bigram/trigram embedding for additional context.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int, + trigram: bool = False): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self._trigram = trigram + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def trigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., :2] = mod + out[..., 2:] = (36313 * t[..., 2:] ^ 27191 * t[..., 1:-1] ^ 51497 * t[..., :-2]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self._trigram: + h = h + self.embed(self.trigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +def _parse_layout(layout_str: str) -> list[tuple[str, int]]: + """Parse a layout string into a sequence of (layer_type, count) pairs. 
+ + Examples: + "gdn_only" -> [("gdn", 11)] (count filled in by caller) + "gdn5_swa_gdn5_swa_shared" -> [("gdn", 5), ("swa", 1), ("gdn", 5), ("swa_shared", 1)] + "gdn3_swa_gdn3_swa_shared_gdn3" -> [("gdn", 3), ("swa", 1), ("gdn", 3), ("swa_shared", 1), ("gdn", 3)] + "mamba_only" -> [("mamba", 11)] + "gdn3_mamba2_swa_gdn3_mamba2" -> [("gdn", 3), ("mamba", 2), ("swa", 1), ("gdn", 3), ("mamba", 2)] + """ + if layout_str == "gdn_only": + return [("gdn", -1)] # -1 = use num_gdn_layers + if layout_str == "mamba_only": + return [("mamba", -1)] # -1 = use num_mamba_layers + if layout_str == "swa_only": + return [("swa", -1)] # -1 = use num_swa_layers + + # Parse custom layouts like "gdn5_swa_gdn5_swa_shared" + parts = layout_str.split("_") + result = [] + i = 0 + while i < len(parts): + part = parts[i] + if part.startswith("gdn") and len(part) > 3: + count = int(part[3:]) + result.append(("gdn", count)) + elif part.startswith("mamba") and len(part) > 5: + count = int(part[5:]) + result.append(("mamba", count)) + elif part == "swa": + # Check if next token is "shared" + if i + 1 < len(parts) and parts[i + 1] == "shared": + result.append(("swa_shared", 1)) + i += 1 + else: + result.append(("swa", 1)) + elif part == "shared": + # Already consumed by swa check above + pass + i += 1 + return result + + +class HybridGDN(nn.Module): + """Hybrid GDN architecture supporting mixed recurrent/attention layers. + + Builds a stack of blocks according to the layer_layout specification: + - "gdn" blocks use GatedDeltaNet (or GatedDeltaProduct, or RWKV-7) + - "mamba" blocks use Mamba-2 + - "swa" blocks use SlidingWindowAttention + - "swa_shared" reuses the same SWA module (Zamba-style weight sharing) + + All models share: token embedding, bigram hash, smear gate, final norm, lm_head. 
+ """ + def __init__(self, config: dict, vocab_size: int = 1024): + super().__init__() + dim = config["model_dim"] + num_heads = config["num_heads"] + mlp_mult = config["mlp_mult"] + self.arch_name = config["arch_name"] + self.model_dim = dim + self.vocab_size = vocab_size + self.logit_softcap = 30.0 + + # Embeddings + self.tok_emb = nn.Embedding(vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.005) + self.bigram = BigramHashEmbedding( + config.get("bigram_vocab_size", 2048), + config.get("bigram_dim", 128), + dim, + trigram=config.get("trigram", False), + ) + self.smear = SmearGate(dim) + + # Meta tokens (Hymba-style, for Model E) + n_meta = config.get("meta_tokens", 0) + if n_meta > 0: + self.meta_tokens = nn.Parameter(torch.randn(1, n_meta, dim) * 0.02) + self.n_meta = n_meta + else: + self.meta_tokens = None + self.n_meta = 0 + + # Build layer stack + layout = _parse_layout(config["layer_layout"]) + self.blocks = nn.ModuleList() + self._block_types = [] # track type for XSA/diagnostics + self._shared_swa = None # shared SWA module for Zamba/Hymba models + + layer_idx = 0 + for layer_type, count in layout: + if count == -1: + # Fill with the specified layer type + if layer_type == "gdn": + count = config["num_gdn_layers"] + elif layer_type == "mamba": + count = config["num_mamba_layers"] + elif layer_type == "swa": + count = config["num_swa_layers"] + + for _ in range(count): + if layer_type == "gdn": + recurrent = self._make_recurrent_layer(config, layer_idx) + block = RecurrentBlock(dim, recurrent, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("gdn") + + elif layer_type == "mamba": + mamba_expand = config.get("mamba_expand", 2) + mamba_head_dim = config.get("gdn_head_dim", 64) + mamba_num_heads = (dim * mamba_expand) // mamba_head_dim + mamba = Mamba2( + num_heads=mamba_num_heads, + head_dim=mamba_head_dim, + hidden_size=dim, + state_size=config.get("mamba_state_size", 64), + expand=mamba_expand, + layer_idx=layer_idx, + ) + block = RecurrentBlock(dim, mamba, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("mamba") + + elif layer_type in ("swa", "swa_shared"): + if layer_type == "swa_shared" and self._shared_swa is not None: + swa = self._shared_swa # reuse same SWA module + else: + swa = SlidingWindowAttention( + dim=dim, + num_heads=num_heads, + num_kv_heads=config.get("swa_num_kv_heads", 4), + window_size=config.get("swa_window", 512), + ) + if config.get("swa_shared", False): + self._shared_swa = swa + + # Each SWA position gets its own MLP even if SWA weights are shared + block = AttentionBlock(dim, swa, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("swa" if layer_type == "swa" else "swa_shared") + + layer_idx += 1 + + # KV sharing: share k/v projections between adjacent layers + kv_stride = config.get("kv_sharing_stride", 0) + if kv_stride > 0: + self._apply_kv_sharing(kv_stride) + + self.final_norm = RMSNorm(dim) + # Tied embeddings (standard for parameter golf) + self.lm_head = None # use tok_emb.weight + self._init_weights() + + def _make_recurrent_layer(self, config: dict, layer_idx: int) -> nn.Module: + """Create the appropriate recurrent layer based on config.""" + dim = config["model_dim"] + num_heads = config["num_heads"] + + if config.get("use_rwkv7", False): + total_layers = config.get("num_gdn_layers", 11) + return RWKV7Attention( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + 
layer_idx=layer_idx, + num_hidden_layers=total_layers, + mode="chunk", + ) + elif config.get("use_deltaproduct", False): + return GatedDeltaProduct( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + num_householder=config.get("dp_num_householder", 2), + allow_neg_eigval=config.get("dp_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + else: + # Default: GatedDeltaNet + return GatedDeltaNet( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + allow_neg_eigval=config.get("gdn_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + + def _apply_kv_sharing(self, stride: int) -> None: + """Share KV projection modules between adjacent layer groups. + + For GDN layers: shares k_proj, v_proj, k_conv1d, v_conv1d. + For SWA layers: shares c_k, c_v. + The first layer in each group is the anchor; subsequent layers in the + group become followers that reference the anchor's modules. + """ + # Collect indices by block type + gdn_indices = [i for i, t in enumerate(self._block_types) if t == "gdn"] + swa_indices = [i for i, t in enumerate(self._block_types) + if t in ("swa", "swa_shared")] + + # Share GDN KV projections within each stride-group + for group_start in range(0, len(gdn_indices), stride): + anchor_idx = gdn_indices[group_start] + anchor = self.blocks[anchor_idx].recurrent + for j in range(1, stride): + if group_start + j >= len(gdn_indices): + break + follower_idx = gdn_indices[group_start + j] + follower = self.blocks[follower_idx].recurrent + follower.k_proj = anchor.k_proj + follower.v_proj = anchor.v_proj + follower.k_conv1d = anchor.k_conv1d + follower.v_conv1d = anchor.v_conv1d + + # Share SWA KV projections within each stride-group + for group_start in range(0, len(swa_indices), stride): + anchor_idx = swa_indices[group_start] + anchor = self.blocks[anchor_idx].attn + for j in range(1, stride): + if group_start + j >= len(swa_indices): + break + follower_idx = swa_indices[group_start + j] + follower = self.blocks[follower_idx].attn + follower.c_k = anchor.c_k + follower.c_v = anchor.c_v + + def _init_weights(self) -> None: + """Weight initialization. + + Each sub-module handles its own init (MLP zeros proj, SWA zeros proj, + FLA layers do own init). We just do the residual scaling for output + projections on our own CastedLinear layers. + """ + total_layers = len(self.blocks) + for name, p in self.named_parameters(): + # Skip FLA-internal parameters + if ".recurrent." 
in name: + continue + # Scale down output projections for residual stream + if p.ndim == 2 and "proj" in name and "bigram" not in name: + with torch.no_grad(): + p.mul_(1.0 / math.sqrt(2 * total_layers)) + + def set_xsa(self, enable: bool = True) -> None: + """Enable/disable XSA on all attention blocks.""" + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + block.attn.use_xsa = enable + + def _compute_logits(self, x: Tensor) -> Tensor: + """Compute logits with tied embeddings and softcap.""" + logits = F.linear(x, self.tok_emb.weight) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + """Forward pass returning cross-entropy loss.""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + + # Prepend meta tokens if Hymba-style + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + x = block(x, x0) + else: + x = block(x, x0) + + # Remove meta tokens before computing logits + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + + x = self.final_norm(x) + logits = self._compute_logits(x.reshape(-1, x.size(-1))) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning logits (for evaluation).""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + x = block(x, x0) + else: + x = block(x, x0) + + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + + x = self.final_norm(x) + return self._compute_logits(x) + + def get_diagnostics(self) -> dict: + """Collect per-layer weight statistics for checkpoint diagnostics.""" + diag = {} + for i, (block, btype) in enumerate(zip(self.blocks, self._block_types)): + prefix = f"layer_{i}_{btype}" + for name, param in block.named_parameters(): + if param.ndim >= 2: + w = param.data.float() + diag[f"{prefix}/{name}/std"] = w.std().item() + diag[f"{prefix}/{name}/kurtosis"] = (((w - w.mean()) / (w.std() + 1e-8)) ** 4).mean().item() - 3.0 + return diag + + def count_params(self) -> dict: + """Count parameters by category.""" + cats = {"embedding": 0, "recurrent": 0, "attention": 0, "mlp": 0, "other": 0} + for name, p in self.named_parameters(): + n = p.numel() + if "tok_emb" in name or "bigram" in name: + cats["embedding"] += n + elif any(k in name for k in ["recurrent", "gdn", "mamba", "rwkv", "delta"]): + cats["recurrent"] += n + elif "attn" in name or "c_q" in name or "c_k" in name or "c_v" in name: + cats["attention"] += n + elif "mlp" in name or "fc" in name: + cats["mlp"] += n + else: + cats["other"] += n + cats["total"] = sum(cats.values()) + return cats diff --git a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/configs.py b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/configs.py new file mode 
100644 index 0000000000..5bbdac3bd4 --- /dev/null +++ b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/configs.py @@ -0,0 +1,316 @@ +"""Model architecture configurations for GDN Hybrid experiments. + +Each config returns a dict consumed by HybridGDN.__init__. +All models are sized to fit ~16MB at int6+zstd-22. + +Models A-H: baseline architecture sweeps. +Models I-K: KV sharing experiments (kv_sharing_stride=2). +""" +from __future__ import annotations + + +def model_a_pure_gdn() -> dict: + """Model A: Pure GDN (Baseline) — 10 layers Gated DeltaNet.""" + return dict( + arch_name="A_PureGDN", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_b_deltaproduct() -> dict: + """Model B: Gated DeltaProduct n_h=2 — rank-2 state transitions.""" + return dict( + arch_name="B_DeltaProduct", + num_gdn_layers=10, # 10 layers to fit param budget + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, # slightly narrower to fit 16MB + num_heads=8, + mlp_mult=3.0, + use_deltaproduct=True, + dp_num_householder=2, + dp_allow_neg_eigval=False, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + ) + + +def model_b2_deltaproduct_neg() -> dict: + """Model B2: DeltaProduct + negative eigenvalues.""" + cfg = model_b_deltaproduct() + cfg["arch_name"] = "B2_DeltaProduct_NegEig" + cfg["dp_allow_neg_eigval"] = True + return cfg + + +def model_c_gdn_neg() -> dict: + """Model C: GDN with negative eigenvalues — richer state dynamics. + + (Originally RWKV-7, replaced because RWKV7 requires Triton kernels with + no pure-PyTorch fallback available.) 
+ """ + return dict( + arch_name="C_GDN_NegEig", + num_gdn_layers=11, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=True, # Key difference: negative eigenvalues + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + ) + + +def model_d_gdn_1swa() -> dict: + """Model D: GDN + 1 Shared SWA (Zamba-style).""" + return dict( + arch_name="D_GDN_1SWA", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=1, + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + # Layout: [GDN×5] → [SWA] → [GDN×5] → [SWA_shared] + layer_layout="gdn5_swa_gdn5_swa_shared", + ) + + +def model_e_gdn_2swa() -> dict: + """Model E: GDN + 2 Shared SWA (Hymba-inspired) with meta-tokens.""" + return dict( + arch_name="E_GDN_2SWA_Hymba", + num_gdn_layers=9, + num_mamba_layers=0, + num_swa_layers=1, # 1 unique, shared at 2 positions + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=4, # Hymba-style prepended meta-tokens + # Layout: [GDN×3] → [SWA] → [GDN×3] → [SWA_shared] → [GDN×3] + layer_layout="gdn3_swa_gdn3_swa_shared_gdn3", + ) + + +def model_f_mamba2() -> dict: + """Model F: Mamba-2 Pure (Mamba-3 proxy with RoPE on B/C).""" + return dict( + arch_name="F_Mamba2", + num_gdn_layers=0, + num_mamba_layers=11, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + mamba_state_size=64, + mamba_expand=2, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="mamba_only", + ) + + +def model_g_hybrid() -> dict: + """Model G: GDN + Mamba-2 + SWA triple hybrid.""" + return dict( + arch_name="G_GDN_Mamba_SWA", + num_gdn_layers=6, + num_mamba_layers=4, + num_swa_layers=1, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + mamba_state_size=64, + mamba_expand=2, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + # Layout: [GDN×3] → [Mamba×2] → [SWA] → [GDN×3] → [Mamba×2] + layer_layout="gdn3_mamba2_swa_gdn3_mamba2", + ) + + +def model_h_pure_swa() -> dict: + """Model H: Pure Sliding Window Attention (standard softmax) — control baseline. + + All 10 layers use causal sliding-window softmax attention (no GDN). + Same MLP, embedding, and normalization as Model A for fair comparison. 
+ """ + return dict( + arch_name="H_PureSWA", + num_gdn_layers=0, + num_mamba_layers=0, + num_swa_layers=10, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="swa_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_i_kv_share() -> dict: + """Model I: GDN + KV Share — same as A but with kv_sharing_stride=2.""" + return dict( + arch_name="I_KVShare", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +def model_j_kv_share_deeper() -> dict: + """Model J: GDN + KV Share + Deeper — 12L dim=480, near iso-parameter to A.""" + return dict( + arch_name="J_KVShare_Deeper", + num_gdn_layers=12, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +def model_k_kv_share_wider() -> dict: + """Model K: GDN + KV Share + Wider — 10L dim=544, iso-parameter to A.""" + return dict( + arch_name="K_KVShare_Wider", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=544, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_sharing_stride=2, + ) + + +ALL_CONFIGS = { + "A": model_a_pure_gdn, + "B": model_b_deltaproduct, + "B2": model_b2_deltaproduct_neg, + "C": model_c_gdn_neg, + "D": model_d_gdn_1swa, + "E": model_e_gdn_2swa, + "F": model_f_mamba2, + "G": model_g_hybrid, + "H": model_h_pure_swa, + "I": model_i_kv_share, + "J": model_j_kv_share_deeper, + "K": model_k_kv_share_wider, +} + + +def get_config(model_id: str) -> dict: + """Get config by model ID (A, B, B2, C, D, E, F, G, H, I, J, K).""" + if model_id not in ALL_CONFIGS: + raise ValueError(f"Unknown model ID '{model_id}'. 
Choose from {list(ALL_CONFIGS.keys())}") + return ALL_CONFIGS[model_id]() diff --git a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/requirements.txt b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/requirements.txt new file mode 100644 index 0000000000..3feaed4f64 --- /dev/null +++ b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/requirements.txt @@ -0,0 +1,10 @@ +numpy +torch +sentencepiece +zstandard +flash-linear-attention==0.4.2 +fla-core==0.4.2 +triton==3.2.0 +transformers==5.5.4 +tokenizers==0.22.2 +safetensors==0.7.0 diff --git a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/submission.json b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/submission.json new file mode 100644 index 0000000000..4d08e1cd3c --- /dev/null +++ b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/submission.json @@ -0,0 +1,38 @@ +{ + "author": "genji0306", + "github_id": "genji0306", + "name": "K_KVShare_Wider FLA (Opensens reproduction)", + "blurb": "Independent 3-seed reproduction of GatedDeltaNet K_KVShare_Wider on 8xH100 SXM (secure cloud). Builds on PR #1687 (resouer). No TTT, no SLOT, no n-gram overlay.", + "date": "2026-04-17", + "track": "10min_16mb", + "val_loss": 3.14335601, + "val_bpb": 1.03385760, + "val_loss_std": 0.00376, + "val_bpb_std": 0.00124, + "seeds": [42, 1337, 2025], + "seed_results": { + "42": { + "val_loss": 3.14766471, + "val_bpb": 1.03527246, + "artifact_bytes": 15927295, + "steps": 1881 + }, + "1337": { + "val_loss": 3.14154729, + "val_bpb": 1.03326043, + "artifact_bytes": 15830641, + "steps": 1890 + }, + "2025": { + "val_loss": 3.14086602, + "val_bpb": 1.03303636, + "artifact_bytes": 15893661, + "steps": 1884 + } + }, + "based_on": "PR #1687 (resouer) — K_KVShare_Wider FLA family", + "artifact_bytes_mean": 15883866, + "artifact_bytes_max": 15927295, + "hardware": "8xH100 80GB SXM (RunPod secure cloud, IN region)", + "technique_summary": "GatedDeltaNet K_KVShare_Wider + EMA + SWA + late QAT + int6 + zstd-22. Independent reproduction with improved results likely due to hardware variance." +} diff --git a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gdn_7k.py b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gdn_7k.py new file mode 100644 index 0000000000..eb030c6103 --- /dev/null +++ b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gdn_7k.py @@ -0,0 +1,1172 @@ +#!/usr/bin/env python3 +"""GDN Hybrid Full Training Script — 7000 steps with all production features. 
+ +Features beyond Phase 1 screening (train_gdn.py): + - EMA weight averaging (decay 0.997) + - SWA (stochastic weight averaging) during late warmdown + - Late QAT (int6 STE in CastedLinear forward during warmdown) + - Mixed int6/int8 quantization with percentile search + - zstd-22 compression for artifact + - Coprime shard ordering via SHARD_ORDER_FILE + - XSA-all sliding window eval on quantized artifact + - Roundtrip validation (load quantized, eval, report exact BPB) + +Environment variables (key additions vs Phase 1): + EMA_DECAY: EMA decay rate (default 0.997) + SWA_ENABLED: 1|0 (default 1) + SWA_EVERY: SWA collection interval (default 50) + LATE_QAT_THRESHOLD: LR scale below which QAT activates (default 0.15) + SHARD_ORDER_FILE: path to file with one shard path per line (coprime ordering) + MUON_MOMENTUM_WARMUP_START: starting momentum for warmup (default 0.85) + MUON_MOMENTUM_WARMUP_STEPS: steps to ramp momentum (default 500) +""" +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +import zstandard +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from architectures import HybridGDN, CastedLinear +from configs import get_config + + +# ─── Hyperparameters ────────────────────────────────────────────────────────── + +class Hyperparameters: + arch_mode = os.environ.get("ARCH_MODE", "A") + data_path = os.environ.get("DATA_PATH", "../data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "../data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + + # Training length + iterations = int(os.environ.get("ITERATIONS", 7000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2100)) # 30% of 7k + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 14100.0)) # 3h55m safety + + # Validation + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + save_every = int(os.environ.get("SAVE_EVERY", 1000)) + + # Optimizer + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = 
float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + + # Eval + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + xsa_eval = bool(int(os.environ.get("XSA_EVAL", "0"))) # during training + eval_compile_enabled = bool(int(os.environ.get("EVAL_COMPILE_ENABLED", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Checkpoint + ckpt_dir = os.environ.get("CKPT_DIR", "checkpoints") + + # Compile + compile_enabled = bool(int(os.environ.get("COMPILE_ENABLED", "1"))) + + # Resume from checkpoint + resume_ckpt = os.environ.get("RESUME_CKPT", "") + + # EMA / SWA + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Late QAT + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # Chained job support + auto_save_seconds = float(os.environ.get("AUTO_SAVE_SECONDS", "0")) + total_iterations = int(os.environ.get("TOTAL_ITERATIONS", "0")) # 0 = same as iterations + + +# ─── Data Loading ───────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + header = np.fromfile(file, dtype=np.uint32, count=256) + assert header[0] == 20240520, f"Bad magic: {header[0]}" + assert header[1] in (1, 7), f"Bad version: {header[1]}" + ntok = int(header[2]) + return torch.from_numpy(np.fromfile(file, dtype=np.uint16, offset=256 * 4)[:ntok].astype(np.int64)) + + +class TokenStream: + """Reads shards sequentially, supports coprime ordering via SHARD_ORDER_FILE.""" + def __init__(self, pattern: str): + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if shard_order_file and os.path.exists(shard_order_file): + with open(shard_order_file) as f: + self.files = [Path(line.strip()) for line in f if line.strip()] + else: + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + assert self.files, f"No files matching {pattern}" + self.idx = 0 + self.buf = load_data_shard(self.files[self.idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.idx = (self.idx + 1) % len(self.files) + self.buf = load_data_shard(self.files[self.idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + parts = [] + remaining = n + while remaining > 0: + avail = self.buf.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + take_n = min(avail, remaining) + parts.append(self.buf[self.pos:self.pos + take_n]) + self.pos += take_n + remaining -= take_n + return torch.cat(parts) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.stream = TokenStream(pattern) + self.rank = rank + self.world_size = world_size + self.device = device + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + tokens_per_rank = global_tokens // self.world_size + seqs_per_rank = tokens_per_rank // seq_len + total_seqs = seqs_per_rank * self.world_size + total_needed = total_seqs * seq_len + 1 + all_tokens = self.stream.take(total_needed) + start = self.rank * seqs_per_rank * seq_len + chunk = all_tokens[start:start + seqs_per_rank * seq_len + 1] + x = chunk[:-1].reshape(seqs_per_rank, seq_len) + y = chunk[1:].reshape(seqs_per_rank, seq_len) + return x.to(self.device), y.to(self.device) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = 
sorted(glob.glob(pattern)) + parts = [load_data_shard(Path(f)) for f in files] + combined = torch.cat(parts) + return combined[:((combined.numel() - 1) // seq_len) * seq_len + 1] + + +def build_sentencepiece_luts(sp, vocab_size, device): + base_bytes = torch.zeros(vocab_size, dtype=torch.float32, device=device) + has_space = torch.zeros(vocab_size, dtype=torch.bool, device=device) + is_boundary = torch.zeros(vocab_size, dtype=torch.bool, device=device) + for i in range(vocab_size): + piece = sp.id_to_piece(i) + raw = piece.encode("utf-8") + base_bytes[i] = len(raw) + if piece.startswith("\u2581"): + has_space[i] = True + base_bytes[i] = len(piece[1:].encode("utf-8")) + 1 + if sp.is_control(i) or sp.is_unknown(i): + is_boundary[i] = True + return base_bytes, has_space, is_boundary + + +def generate_coprime_shard_order(shard_files: list, seed: int = 42) -> list: + """Generate a coprime-stepping shard order for better data mixing. + + Instead of sequential 0,1,2,...,79, uses stride = coprime(N) to visit + all shards in a maximally-spread order. This ensures each training epoch + sees data from diverse shards rather than correlated sequential ones. + """ + n = len(shard_files) + if n <= 1: + return shard_files + + # Find a coprime stride that's roughly n/phi (golden ratio) + target = max(1, int(n / 1.618)) + stride = target + while math.gcd(stride, n) != 1: + stride += 1 + + rng = random.Random(seed) + start = rng.randint(0, n - 1) + order = [] + pos = start + for _ in range(n): + order.append(shard_files[pos]) + pos = (pos + stride) % n + return order + + +# ─── Muon Optimizer ────────────────────────────────────────────────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if transposed: + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay) + super().__init__(params, defaults) + + def step(self, closure=None): + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + for p in group["params"]: + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g + momentum * buf + else: + g = buf + if g.ndim == 2 and min(g.shape) >= 2: + g = zeropower_via_newtonschulz5(g, steps=group["backend_steps"]) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.data.add_(g, alpha=-lr) + + +# ─── Evaluation ────────────────────────────────────────────────────────────── + +def eval_val_sliding( + model: nn.Module, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + rank: int, + world_size: int, + device: torch.device, + seq_len: int = 1024, + stride: int = 64, + batch_seqs: int = 128, + xsa_eval: bool = False, + compile_enabled: bool = True, +) -> tuple[float, float]: + """Score-first sliding window evaluation.""" + total_tokens = val_tokens.numel() - 1 + window_starts = 
[ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + base_model = model.module if hasattr(model, 'module') else model + if xsa_eval and hasattr(base_model, 'set_xsa'): + base_model.set_xsa(True) + + forward_fn = base_model.forward_logits + compiled_logits = forward_fn + if compile_enabled: + try: + compiled_logits = torch.compile(forward_fn, dynamic=False) + except Exception: + compiled_logits = forward_fn + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + + if xsa_eval and hasattr(base_model, 'set_xsa'): + base_model.set_xsa(False) + + model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ─── Quantization ──────────────────────────────────────────────────────────── + +# Control tensor patterns — kept at full precision during quantization +CONTROL_PATTERNS = ( + "resid_mix", "q_gain", "smear", "skip_weight", "attn_scale", "mlp_scale", +) + + +def generate_autoregressive_calib(model, device, num_seqs=64, seq_len=2048, + vocab_size=1024, temperature=0.8, batch_size=8, seed=42): + """Generate sequences autoregressively from the model for GPTQ calibration. 
+ No external data accessed — fully self-contained.""" + model.eval() + rng = torch.Generator(device=device) + rng.manual_seed(seed) + all_tokens = [] + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for batch_start in range(0, num_seqs, batch_size): + bs = min(batch_size, num_seqs - batch_start) + tokens = torch.randint(0, vocab_size, (bs, 1), device=device, generator=rng) + for pos in range(seq_len - 1): + logits = model.forward_logits(tokens) + next_logit = logits[:, -1, :] + probs = torch.softmax(next_logit / temperature, dim=-1) + next_tok = torch.multinomial(probs, 1, generator=rng) + tokens = torch.cat([tokens, next_tok], dim=1) + for i in range(bs): + all_tokens.append(tokens[i:i+1]) + return all_tokens + + +def collect_hessians_from_tokens(hessian_model, token_seqs, device): + """Collect H = X^T X from pre-generated token sequences.""" + hessians = {} + hooks = [] + for name, module in hessian_model.named_modules(): + if isinstance(module, CastedLinear): + param_name = name + ".weight" + cols = module.weight.shape[1] + hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') + def make_hook(pname): + def hook_fn(module, input, output): + x = input[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[pname] += (x.T @ x).cpu() + return hook_fn + h = module.register_forward_hook(make_hook(param_name)) + hooks.append(h) + hessian_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for seq in token_seqs: + x = seq[:, :-1].to(device) + y = seq[:, 1:].to(device) + hessian_model(x, y) + for h in hooks: + h.remove() + num_batches = len(token_seqs) + for name in hessians: + hessians[name] /= num_batches + return hessians + + +def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128): + """Full GPTQ: Hessian-aware int6 quantization with Cholesky error compensation. 
+ If hessian is None, falls back to percentile search.""" + t32 = weight.float() + if t32.ndim != 2 or hessian is None: + return quantize_int6_per_row(t32) + rows, cols = t32.shape + H = hessian.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * torch.mean(torch.diag(H)) + H[torch.arange(cols), torch.arange(cols)] += damp + perm = torch.argsort(torch.diag(H), descending=True) + inv_perm = torch.argsort(perm) + W = t32[:, perm].clone() + W[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.linalg.cholesky(H) + Hinv = torch.cholesky_inverse(Hinv) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + best_q = None; best_scale = None; best_err = float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + sf = s.float() + Q = torch.zeros_like(W, dtype=torch.int8) + W_work = W.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + count = i2 - i1 + W1 = W_work[:, i1:i2].clone() + Q1 = torch.zeros(rows, count, dtype=torch.int8) + Err1 = torch.zeros(rows, count) + Hinv1 = Hinv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8) + Q1[:, i] = q + err = (w - q.float() * sf) / d + W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0) + Err1[:, i] = err + Q[:, i1:i2] = Q1 + if i2 < cols: + W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + recon = Q.float() * sf[:, None] + mse = (W - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + best_q = best_q[:, inv_perm] + return best_q, best_scale + + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + """Int6 quantization with percentile search for optimal clipping.""" + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + # 1D: simple per-tensor quantization + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + + +def quantize_int8_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + """Int8 quantization with percentile clipping.""" + t32 = t.float() + clip_q = 0.9999984 + if t32.ndim == 2: + clip_abs = torch.quantile(t32.abs(), clip_q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0).to(torch.float16) + q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -127, 127).to(torch.int8) + return q, scale + clip_abs = float(torch.quantile(t32.abs().flatten(), clip_q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, 
dtype=torch.float16) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale.float()), -127, 127).to(torch.int8) + return q, scale + + +def mixed_quantize(state_dict: dict[str, Tensor], hessians: dict[str, Tensor] | None = None) -> tuple[dict[str, Tensor], dict[str, object]]: + """Mixed int6 (large weights) / int8 (small weights) quantization.""" + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Control tensors: keep fp16 passthrough + if any(p in name for p in CONTROL_PATTERNS): + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + # Non-float: passthrough + if not t.is_floating_point(): + result[name] = t + meta[name] = "passthrough" + continue + # Small tensors: fp16 passthrough + if t.numel() <= 65536: + result[name] = t.to(torch.float16) + meta[name] = "passthrough" + continue + # Large 2D weights: int6 (6-bit quantization for better compression) + if t.ndim == 2 and t.numel() > 65536: + H = hessians.get(name) if hessians else None + q, s = quantize_int6_gptq(t, hessian=H) if H is not None else quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + # Other float tensors: int8 + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + + +def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Dequantize mixed int6/int8 back to float.""" + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info == "passthrough": + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ─── Checkpoint Saving ─────────────────────────────────────────────────────── + +def save_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed): + base = model.module if hasattr(model, 'module') else model + ckpt = { + "step": step, "val_bpb": val_bpb, + "arch_name": arch_name, "seed": seed, + "model_state_dict": base.state_dict(), + } + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"{arch_name}_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def save_full_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + qat_enabled, rng_states=None, stream_state=None): + """Save complete training state for chained job resume.""" + base = model.module if hasattr(model, 'module') else model + ckpt = { + "step": step, "val_bpb": val_bpb, + "arch_name": arch_name, "seed": seed, + "model_state_dict": {k: v.cpu() for k, v in base.state_dict().items()}, + "muon_opt_state": muon_opt.state_dict(), + "adam_opt_state": adam_opt.state_dict(), + "ema_state": {k: v.cpu() for k, v in ema_state.items()}, + "swa_state": {k: v.cpu() for k, v in swa_state.items()} if swa_state is not None else None, + "swa_count": swa_count, + "qat_enabled": qat_enabled, + } + if rng_states is not None: + 
ckpt["rng_states"] = rng_states + if stream_state is not None: + ckpt["stream_state"] = stream_state + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"full_ckpt_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def _find_latest_full_ckpt(ckpt_dir): + """Find the latest full_ckpt_step*.pt file in ckpt_dir by step number.""" + pattern = os.path.join(ckpt_dir, "full_ckpt_step*_seed*.pt") + files = glob.glob(pattern) + if not files: + return None + import re + step_re = re.compile(r"full_ckpt_step(\d+)_seed") + best_step, best_path = -1, None + for f in files: + m = step_re.search(os.path.basename(f)) + if m: + s = int(m.group(1)) + if s > best_step: + best_step, best_path = s, f + return best_path + + +# ─── Main Training Loop ───────────────────────────────────────────────────── + +def main(): + global zeropower_via_newtonschulz5 + args = Hyperparameters() + config = get_config(args.arch_mode) + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = max(1, 8 // world_size) + master_process = rank == 0 + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if args.compile_enabled: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # Logging + os.makedirs("logs", exist_ok=True) + os.makedirs(args.ckpt_dir, exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master_process else None + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg, flush=True) + if logfile: + with open(logfile, "a") as f: + print(msg, file=f) + + # Seeds + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + log0(f"=== GDN Hybrid 7k Full Training ===") + log0(f"Arch: {config['arch_name']} (ARCH_MODE={args.arch_mode})") + log0(f"Seed: {args.seed}, Steps: {args.iterations}, Warmdown: {args.warmdown_iters}") + log0(f"World size: {world_size}, Grad accum: {grad_accum_steps}") + log0(f"EMA decay: {args.ema_decay}, SWA: {args.swa_enabled} (every {args.swa_every})") + log0(f"Late QAT threshold: {args.late_qat_threshold}") + log0(f"Eval compile enabled: {args.eval_compile_enabled}") + + # Tokenizer + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + assert int(sp.vocab_size()) == args.vocab_size + + # Validation data + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"Validation tokens: {val_tokens.numel()-1:,}") + + # Build model + _t0 = time.time() + model = HybridGDN(config, args.vocab_size) + model = model.to(device).bfloat16() + log0(f"Model built in {time.time()-_t0:.1f}s") + + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, p in model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + + param_counts = model.count_params() + log0(f"Parameters: {param_counts}") + log0(f"Total params: 
{param_counts['total']:,}") + + # Resume from checkpoint if specified + start_step = 0 + resume_state = None # holds full checkpoint data for deferred restore + resume_ckpt_path = args.resume_ckpt + if resume_ckpt_path == "auto": + resume_ckpt_path = _find_latest_full_ckpt(args.ckpt_dir) or "" + if resume_ckpt_path: + log0(f"Auto-detected resume checkpoint: {resume_ckpt_path}") + else: + log0("Auto-resume: no full checkpoint found, starting fresh") + if resume_ckpt_path and os.path.exists(resume_ckpt_path): + log0(f"Resuming from checkpoint: {resume_ckpt_path}") + ckpt = torch.load(resume_ckpt_path, map_location="cpu", weights_only=False) + base_sd = ckpt["model_state_dict"] + model.load_state_dict({k: v.to(device) for k, v in base_sd.items()}, strict=True) + start_step = ckpt.get("step", 0) + log0(f"Resumed model at step {start_step}, val_bpb={ckpt.get('val_bpb', 'N/A')}") + # Keep full checkpoint for deferred optimizer/EMA/SWA restore + if "muon_opt_state" in ckpt: + resume_state = ckpt + log0(" Full checkpoint detected — will restore optimizers, EMA, SWA, RNG") + else: + log0(" Lightweight checkpoint — model only") + del ckpt + + # DDP + base_model = model # keep reference before wrapping + if distributed: + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + + # Optimizer setup + matrix_params = [] + scalar_params = [] + embed_params = [] + for name, p in base_model.named_parameters(): + if not p.requires_grad: + continue + if "tok_emb" in name: + embed_params.append(p) + elif p.ndim == 2 and min(p.shape) >= 2: + matrix_params.append(p) + else: + scalar_params.append(p) + + log0(f"Matrix params: {sum(p.numel() for p in matrix_params):,}") + log0(f"Scalar params: {sum(p.numel() for p in scalar_params):,}") + log0(f"Embed params: {sum(p.numel() for p in embed_params):,}") + + muon_opt = Muon( + matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + adam_opt = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr}, + {"params": embed_params, "lr": args.tied_embed_lr}], + betas=(args.beta1, args.beta2), + weight_decay=args.adam_wd, + fused=True, + ) + + # Deferred restore: optimizer states (must happen after optimizer creation) + if resume_state is not None: + muon_opt.load_state_dict(resume_state["muon_opt_state"]) + adam_opt.load_state_dict(resume_state["adam_opt_state"]) + log0(" Restored optimizer states (Muon + Adam)") + + # Data loader — coprime shard ordering + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if not shard_order_file: + # Generate coprime ordering on-the-fly + shard_files = sorted(glob.glob(args.train_files)) + if shard_files: + ordered = generate_coprime_shard_order(shard_files, seed=args.seed) + shard_order_path = f"/tmp/shard_order_{args.run_id}.txt" + with open(shard_order_path, "w") as f: + for sf in ordered: + f.write(str(sf) + "\n") + os.environ["SHARD_ORDER_FILE"] = shard_order_path + log0(f"Generated coprime shard order: stride across {len(shard_files)} shards") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # LR schedule with warmdown (cosine) + def lr_schedule(step: int) -> float: + warmdown_start = args.iterations - args.warmdown_iters + if step < args.warmup_steps: + return step / max(1, args.warmup_steps) + elif step >= warmdown_start: + progress = (step - warmdown_start) / args.warmdown_iters + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return 1.0 + + # ─── EMA + SWA 
state ───────────────────────────────────────────────── + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Deferred restore: EMA, SWA, QAT, RNG, stream state + if resume_state is not None: + # EMA + saved_ema = resume_state.get("ema_state") + if saved_ema is not None: + ema_state = {k: v.to(device).float() for k, v in saved_ema.items()} + log0(" Restored EMA state") + # SWA + saved_swa = resume_state.get("swa_state") + if saved_swa is not None: + swa_state = {k: v.cpu() for k, v in saved_swa.items()} + swa_count = resume_state.get("swa_count", 0) + log0(f" Restored SWA state (count={swa_count})") + else: + swa_count = resume_state.get("swa_count", 0) + # QAT + if resume_state.get("qat_enabled", False): + CastedLinear._qat_enabled = True + log0(" Restored QAT enabled state") + # RNG states + saved_rng = resume_state.get("rng_states") + if saved_rng is not None: + torch.set_rng_state(saved_rng["torch_cpu"]) + torch.cuda.set_rng_state(saved_rng["torch_cuda"]) + np.random.set_state(saved_rng["numpy"]) + random.setstate(saved_rng["python"]) + log0(" Restored RNG states") + # Stream state (data loader fast-forward) + saved_stream = resume_state.get("stream_state") + if saved_stream is not None: + s_idx, s_pos = saved_stream + stream = train_loader.stream + # Advance to the saved shard + while stream.idx != s_idx: + stream._advance_file() + stream.pos = s_pos + log0(f" Restored stream state (shard={s_idx}, pos={s_pos})") + else: + # No stream state saved — fast-forward by consuming tokens + if start_step > 0: + log0(f" Fast-forwarding data loader by {start_step} steps...") + for _ in range(start_step): + train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + log0(f" Data loader advanced to step {start_step}") + del resume_state + log0(" Full checkpoint restore complete") + + # ─── Training Loop ─────────────────────────────────────────────────── + # Clear stale chain marker from previous segment (if any) + stale_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(stale_marker): + os.remove(stale_marker) + + log0(f"\n{'='*80}") + log0(f"Starting training: {args.iterations} steps (from step {start_step})") + log0(f"{'='*80}\n") + + t0 = time.time() + running_loss = 0.0 + loss_count = 0 + stop_after_step = None + step = start_step # ensure step is defined even if loop doesn't execute + + for step in range(start_step + 1, args.iterations + 1): + # Check early stop + if stop_after_step is not None and step > stop_after_step: + log0(f"Stopping early at step {step} (wallclock limit)") + break + + lr_mul = lr_schedule(step) + + # Muon momentum warmup + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + current_muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in muon_opt.param_groups: + group["lr"] = args.matrix_lr * lr_mul + group["momentum"] = current_muon_momentum + for i, pg in enumerate(adam_opt.param_groups): + if i == 0: + pg["lr"] = args.scalar_lr * lr_mul + else: + pg["lr"] = args.tied_embed_lr * lr_mul + + # Late QAT: activate int6 STE only during warmdown (not warmup!) 
+ warmdown_start = args.iterations - args.warmdown_iters + if (args.late_qat_threshold > 0 and step >= warmdown_start + and lr_mul < args.late_qat_threshold and not CastedLinear._qat_enabled): + CastedLinear._qat_enabled = True + log0(f"Late QAT enabled at step {step} (lr_mul={lr_mul:.4f})") + + # Gradient accumulation + model.train() + total_loss = 0.0 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + micro_batch = x.shape[0] // grad_accum_steps + for micro_step in range(grad_accum_steps): + x_micro = x[micro_step * micro_batch:(micro_step + 1) * micro_batch] + y_micro = y[micro_step * micro_batch:(micro_step + 1) * micro_batch] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x_micro, y_micro) + loss = loss / grad_accum_steps + loss.backward() + total_loss += loss.item() + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm) + + muon_opt.step() + adam_opt.step() + muon_opt.zero_grad(set_to_none=True) + adam_opt.zero_grad(set_to_none=True) + + # EMA update (every step) + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + # SWA: collect checkpoints during late warmdown + if args.swa_enabled and lr_mul < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"SWA started at step {step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + running_loss += total_loss + loss_count += 1 + + # Logging + if step % args.train_log_every == 0 or step <= 10: + avg_loss = running_loss / max(loss_count, 1) + elapsed = time.time() - t0 + steps_per_sec = step / elapsed + log0(f"step {step:5d}/{args.iterations} | loss {avg_loss:.4f} | lr_mul {lr_mul:.4f} | " + f"mom {current_muon_momentum:.3f} | {steps_per_sec:.2f} steps/s | {elapsed:.0f}s") + running_loss = 0.0 + loss_count = 0 + + # Validation + checkpoint + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or step == args.iterations: + val_loss, val_bpb = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=args.xsa_eval, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"step {step:5d} | val_loss {val_loss:.4f} | val_bpb {val_bpb:.4f}") + + if master_process and args.save_every > 0 and (step % args.save_every == 0 or step == args.iterations): + ckpt_path = save_checkpoint( + model, step, val_bpb, args.ckpt_dir, config["arch_name"], args.seed, + ) + log0(f" Saved: {ckpt_path}") + + # Wallclock limit + if args.max_wallclock_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.max_wallclock_seconds and stop_after_step is None: + stop_after_step = step + log0(f"Wallclock limit reached ({elapsed:.0f}s), will stop after this step") + + # Auto-save for chained job support + if args.auto_save_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.auto_save_seconds: + log0(f"Auto-save triggered at step {step} ({elapsed:.0f}s elapsed)") + if master_process: + rng_states = { + "torch_cpu": torch.get_rng_state(), + "torch_cuda": torch.cuda.get_rng_state(), + "numpy": np.random.get_state(), + "python": random.getstate(), + } + stream = train_loader.stream + 
stream_state = (stream.idx, stream.pos) + ckpt_path = save_full_checkpoint( + model, step, 0.0, args.ckpt_dir, config["arch_name"], args.seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + CastedLinear._qat_enabled, + rng_states=rng_states, stream_state=stream_state, + ) + # Write chain resume marker + marker_path = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + with open(marker_path, "w") as f: + f.write(ckpt_path + "\n") + log0(f" Full checkpoint saved: {ckpt_path}") + log0(f" Chain resume marker: {marker_path}") + break # exit training loop cleanly + + # ─── Check if we exited due to auto-save vs normal completion ──────── + chain_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(chain_marker): + log0("\nExiting for chained job resume (skipping post-training)") + if distributed: + dist.destroy_process_group() + return + + # Check for total_iterations completion + effective_total = args.total_iterations if args.total_iterations > 0 else args.iterations + if master_process and step >= effective_total: + complete_marker = os.path.join(args.ckpt_dir, f"TRAINING_COMPLETE_seed{args.seed}") + with open(complete_marker, "w") as f: + f.write(f"step={step}\n") + + # ─── Post-Training: Apply EMA ──────────────────────────────────────── + elapsed_total = time.time() - t0 + log0(f"\nTraining complete in {elapsed_total:.0f}s") + log0(f"Peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + log0("\n=== Applying EMA weights ===") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + # Eval EMA weights + val_loss_ema, val_bpb_ema = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"EMA BPB (no XSA): {val_bpb_ema:.6f}") + + # Save raw EMA model + if master_process: + torch.save(base_model.state_dict(), os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.pt")) + log0("Saved raw EMA model") + + # ─── GPTQ Calibration (optional) ───────────────────────────────────── + gptq_enabled = bool(int(os.environ.get("GPTQ_ENABLED", "0"))) + hessians = None + if gptq_enabled: + log0("\n=== GPTQ: generating autoregressive calibration data ===") + calib_seqs = generate_autoregressive_calib( + base_model, device, num_seqs=64, seq_len=args.train_seq_len, + vocab_size=args.vocab_size, temperature=0.8, batch_size=8, seed=args.seed, + ) + log0(f"GPTQ: generated {len(calib_seqs)} sequences, collecting hessians...") + hessians = collect_hessians_from_tokens(base_model, calib_seqs, device) + log0(f"GPTQ: collected hessians for {len(hessians)} layers") + + # ─── Quantization + Artifact Creation ──────────────────────────────── + log0("\n=== Quantizing to int6 + zstd-22 ===") + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, hessians=hessians) + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + + artifact_path = os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.int6.ptz") + if master_process: + with open(artifact_path, "wb") as f: + 
f.write(quant_blob) + artifact_bytes = len(quant_blob) + log0(f"Artifact: {artifact_bytes:,} bytes ({artifact_bytes / 1024 / 1024:.2f} MB)") + if artifact_bytes > 16 * 1024 * 1024: + log0(f"WARNING: Artifact exceeds 16MB budget by {(artifact_bytes - 16*1024*1024) / 1024:.1f} KB") + + # ─── Roundtrip Validation ──────────────────────────────────────────── + log0("\n=== Roundtrip Validation (quantized model) ===") + if distributed: + dist.barrier() + + with open(artifact_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed(quant_state["w"], quant_state["m"], sd_cpu) + + # Build fresh eval model (no DDP wrapping needed) + eval_model = HybridGDN(config, args.vocab_size).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + for name, p in eval_model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + eval_model.load_state_dict(deq_state, strict=True) + + # Eval quantized model without XSA + val_loss_q, val_bpb_q = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"Quantized BPB (no XSA): {val_bpb_q:.6f}") + log0(f"Quantization degradation: {val_bpb_q - val_bpb_ema:+.6f}") + + # Eval quantized model WITH XSA (if model has SWA layers) + block_types = eval_model._block_types + if any(bt in ("swa", "swa_shared") for bt in block_types): + val_loss_qx, val_bpb_qx = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=True, + compile_enabled=args.eval_compile_enabled, + ) + log0(f"Quantized BPB (XSA-all): {val_bpb_qx:.6f}") + + # ─── Final Summary ─────────────────────────────────────────────────── + log0(f"\n{'='*80}") + log0(f"FINAL RESULTS — {config['arch_name']} seed={args.seed}") + log0(f" Training: {args.iterations} steps, {elapsed_total:.0f}s") + log0(f" EMA BPB (fp32): {val_bpb_ema:.6f}") + log0(f" Quantized BPB: {val_bpb_q:.6f}") + if any(bt in ("swa", "swa_shared") for bt in block_types): + log0(f" Quantized BPB+XSA: {val_bpb_qx:.6f}") + if master_process: + log0(f" Artifact size: {artifact_bytes:,} bytes") + log0(f"{'='*80}") + log0(f"final_int6_roundtrip_exact val_loss:{val_loss_q:.8f} val_bpb:{val_bpb_q:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gpt.py b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gpt.py new file mode 100644 index 0000000000..c92492b069 --- /dev/null +++ b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gpt.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""FLA / GatedDeltaNet entrypoint wrapper. + +The actual training logic lives in `train_gdn_7k.py`. `evaluate.py` expects +`torchrun train_gpt.py`, so this wrapper preserves the standard repo entrypoint +while keeping the scored path in the records folder self-contained. +""" + +import os +import sys +import traceback +from pathlib import Path + +# These defaults keep the wrapper aligned with the intended SP8192 scored path. 
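+# Only variables the launcher left unset are filled in; each os.environ.setdefault
+# call below is a no-op when the evaluator already exported an explicit value.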
+
+VOCAB_SIZE = int(os.environ.get("VOCAB_SIZE", 8192))
+DATA_PATH = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192")
+TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model")
+ARCH_MODE = os.environ.get("ARCH_MODE", "K")
+os.environ.setdefault("VOCAB_SIZE", str(VOCAB_SIZE))
+os.environ.setdefault("DATA_PATH", DATA_PATH)
+os.environ.setdefault("TOKENIZER_PATH", TOKENIZER_PATH)
+os.environ.setdefault("ARCH_MODE", ARCH_MODE)
+os.environ.setdefault("MAX_WALLCLOCK_SECONDS", "600")
+os.environ.setdefault("VAL_LOSS_EVERY", "0")
+os.environ.setdefault("EVAL_COMPILE_ENABLED", "0")
+if ARCH_MODE in ("D", "G", "M"):
+    os.environ.setdefault("XSA_EVAL", "1")
+
+
+_VENDOR_DIR = Path(__file__).resolve().parent / ".fla_vendor"
+_VENDOR_PKGS = [
+    "triton==3.2.0",
+    "flash-linear-attention==0.4.2",
+    "fla-core==0.4.2",
+    "transformers==5.5.4",
+    "tokenizers==0.22.2",
+    "safetensors==0.7.0",
+]
+if ARCH_MODE in ("F", "G"):
+    _VENDOR_PKGS.extend(
+        [
+            "mamba-ssm==2.3.1",
+            "causal-conv1d==1.6.1",
+        ]
+    )
+
+
+def _ensure_vendor_on_path() -> None:
+    p = str(_VENDOR_DIR)
+    if p not in sys.path:
+        sys.path.insert(0, p)
+
+
+def _ensure_fla_vendor_available() -> None:
+    _ensure_vendor_on_path()
+    try:
+        if ARCH_MODE in ("F", "G"):
+            from fla.layers.mamba2 import Mamba2  # noqa: F401
+            from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined  # noqa: F401
+            from causal_conv1d import causal_conv1d_fn  # noqa: F401
+        else:
+            from fla.layers.gated_deltanet import GatedDeltaNet  # noqa: F401
+        print("wrapper: local vendored FLA imports already work", flush=True)
+        return
+    except Exception as exc:
+        # Chain the original ImportError so the failing package is visible in the traceback.
+        vendor_pkgs = ", ".join(_VENDOR_PKGS)
+        raise RuntimeError(
+            "wrapper: required FLA deps are missing from the local environment. "
+            f"Expected vendored packages under {_VENDOR_DIR}. "
+            f"Install them before evaluation (e.g. via launcher/requirements), packages: {vendor_pkgs}"
+        ) from exc
+
+
+def main():
+    _ensure_fla_vendor_available()
+    print("wrapper: importing train_gdn_7k", flush=True)
+    try:
+        from train_gdn_7k import main as train_main
+    except Exception:
+        traceback.print_exc()
+        raise
+    print("wrapper: import ok, entering train_main", flush=True)
+    train_main()
+
+
+if __name__ == "__main__":
+    main()

From 14eae0a144abe305f6f8143adb86770688e7b6b4 Mon Sep 17 00:00:00 2001
From: genji0306
Date: Sat, 18 Apr 2026 06:31:41 +0700
Subject: [PATCH 3/3] fix: correct SentencePiece byte-accounting LUT to match base repo

- is_boundary defaults to True (was zeros)
- skip control/unknown/unused tokens early
- handle byte tokens as 1 byte explicitly
- strip sentencepiece space marker before UTF-8 encoding
- use int16 for base_bytes (was float32)

Same bug that closed PR #1687.
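
For illustration, the corrected LUT behavior on example pieces (hypothetical
entries, not specific SP8192 ids):

    control/unknown/unused -> base_bytes=0, is_boundary=True  (skipped early)
    byte token <0xE2>      -> base_bytes=1, is_boundary=False
    "\u2581the"            -> base_bytes=3, has_space=True   (marker stripped)
    "ing"                  -> base_bytes=3, has_space=False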
--- .../train_gdn_7k.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gdn_7k.py b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gdn_7k.py index eb030c6103..847f082ab8 100644 --- a/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gdn_7k.py +++ b/records/track_10min_16mb/2026-04-17_KKVShareWider_FLA_Opensens/train_gdn_7k.py @@ -190,18 +190,21 @@ def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: def build_sentencepiece_luts(sp, vocab_size, device): - base_bytes = torch.zeros(vocab_size, dtype=torch.float32, device=device) + base_bytes = torch.zeros(vocab_size, dtype=torch.int16, device=device) has_space = torch.zeros(vocab_size, dtype=torch.bool, device=device) - is_boundary = torch.zeros(vocab_size, dtype=torch.bool, device=device) + is_boundary = torch.ones(vocab_size, dtype=torch.bool, device=device) for i in range(vocab_size): + if sp.is_control(i) or sp.is_unknown(i) or sp.is_unused(i): + continue + is_boundary[i] = False + if sp.is_byte(i): + base_bytes[i] = 1 + continue piece = sp.id_to_piece(i) - raw = piece.encode("utf-8") - base_bytes[i] = len(raw) if piece.startswith("\u2581"): has_space[i] = True - base_bytes[i] = len(piece[1:].encode("utf-8")) + 1 - if sp.is_control(i) or sp.is_unknown(i): - is_boundary[i] = True + piece = piece[1:] + base_bytes[i] = len(piece.encode("utf-8")) return base_bytes, has_space, is_boundary
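
For context, here is a minimal sketch (not part of the patch) of how LUTs with
these semantics typically feed a bits-per-byte score. The function name and the
treatment of `has_space` as one extra byte are assumptions for illustration; the
repo's `eval_val_sliding` may differ in details such as boundary and stride
handling.

```python
import math
import torch

def bits_per_byte(nll_nats: torch.Tensor, token_ids: torch.Tensor,
                  base_bytes: torch.Tensor, has_space: torch.Tensor) -> float:
    # Hypothetical reduction: each scored token costs its piece bytes plus one
    # byte for the stripped leading-space marker when has_space is set.
    bytes_per_token = base_bytes[token_ids].long() + has_space[token_ids].long()
    total_bytes = int(bytes_per_token.sum().item())
    total_nats = float(nll_nats.sum().item())
    # Convert summed NLL from nats to bits, then normalize by UTF-8 byte count.
    return total_nats / (total_bytes * math.log(2))
```

As a consistency check under this accounting, the package's reported mean of
3.1648 nats/token at 1.0409 BPB implies roughly 4.4 UTF-8 bytes per validation
token.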