diff --git a/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/README.md b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/README.md new file mode 100644 index 0000000000..bb63f1b3fb --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/README.md @@ -0,0 +1,125 @@ +# TTSM: Typical Ternary State-Space Model + +**Author:** Ambivalence (dd_dent) +**Track:** 10min/16MB +**val_bpb:** 2.0032 (seed 42) +**Comparison baseline:** PR #1644 (mradassaad, Mamba-3 Hybrid, 1.1473 bpb) +**Artifact:** 12,039,626 bytes | **Params:** 11M (7.8M ternary at 1.6 bits/param, 3.3M fp16/fp32 dynamics) +**Date:** April 30, 2026 + +--- + +## State Is Protected + +In selective SSMs, the ternary quantization boundary falls at the B and C projections, not in the hidden state. The state vector h_t never sees the ternary constraint. This is why ternary Mamba works. + +In the selective SSM recurrence: + +``` +h_t = exp(Δ_t ⊙ A) ⊙ h_{t-1} + Δ_t ⊙ B_t ⊙ x_t +y_t = C_t ⊙ h_t + D ⊙ x_t +``` + +B_t controls *whether* each input is written into state — write at full scale, write nothing, or write negated. C_t controls *which* state channels are read out. Both are gates operating on a continuous fp16 state. Errors in B degrade write selectivity; they do not corrupt existing state. Errors in C degrade read selectivity; they do not touch what is being read. + +The contrast with DeltaNet: k_t appears in both the state update and the readout simultaneously. An error in k_t propagates bidirectionally. Ternary B/C is structurally easier than ternary DeltaNet k_t. We haven't tested this. + +Δ_t (the discretization step) is *not* the easy case either. Δ enters exp(ΔA), where small errors compound multiplicatively through the recurrence. We kept dt_proj in fp16 and A_log in fp32. mradassaad (PR #1890, same author as our comparison baseline) independently reached the same boundary — their Mamba-3 hybrid collapsed under INT6 until they promoted A/dt to INT8. The ternary boundary is at B and C. + +The model confirms this structure empirically. B activations stayed stable across training (std ≈ 0.009 at convergence); C activations were highly variable (std ≈ 27.1). The write gate locked; the readout explored. We didn't engineer it. + +--- + +## Notes on Engineering + +**Reversed-scan backward.** The Triton forward kernel for chunk-wise parallel scan runs in ~1ms per step. The naive backward (PyTorch compiled autograd) took 31 seconds — 1,024 sequential Python→CUDA sync points. The backward recurrence `Δh[t] = do[t] + exp(g[t+1]) * Δh[t+1]` is the same operation as the forward scan, reversed in time. Flip the inputs, run the forward kernel, flip the output. 15 lines. 31s → 1.2s per step (26×). + +This generalizes: any recurrence whose backward is the same recurrence reversed in time can reuse its forward kernel. + +**NS=5 outperforms NS=10.** DeepSeek-V4 uses 10 Muon Newton-Schulz iterations. We confirmed NS=10 gives 52× better orthogonality than NS=5. It also gives worse val_bpb. The less-orthogonal Muon step acts as a diversity regularizer for the ternary weight competition — under STE, imprecision in the optimizer step prevents premature commitment to ternary assignments. Finite-resource optima require finite imprecision. + +**Overtraining degrades quality.** Running beyond the 600s cutoff (1×H100 long-burn), val_bpb rises past 2.08 by step 5000, compared to ~1.93 at the competition-equivalent step count. A phase transition around step ~3000 marks over-crystallization: flip rate collapses from ~12% to <1%, and the ternary assignments lock into a suboptimal configuration. The 600s budget is coincidentally near-optimal. + +**Frozen conv outperforms trained conv.** Short conv weights frozen at kaiming initialization outperform trained conv by 0.07 bpb on FineWeb. Random initialization provides an unbiased local smoother that task-specific training degrades. The submitted artifact uses trained conv — this was discovered after the submission run, and GPU availability precluded a rerun. + +**Trit packing at entropy.** Ternary weights pack at 5 trits per byte (1.6 bits/param). The artifact is compressed with zlib; zstd achieves equivalent ~0% compression — the encoding is already at entropy. Artifact budget is deterministic: `ternary_params = bytes × 5`. At 1.6 bits/param, 16 MB buys ~80M ternary parameters vs ~25M at int5. Whether the capacity gain exceeds the precision loss is the experiment this submission runs. + +**Constrained-optimum architecture.** Under the 10min/16MB budget, optimal model size shrinks relative to unconstrained scaling. Steps-per-second dominates parameters-per-bit. We swept d ∈ {384, 512, 576}, blocks ∈ {5, 7, 9}, D_STATE ∈ {32, 64}, and MATRIX_LR over five points. The architecture below is the result of this search. + +**Z-loss for STE stability.** Large logits saturate the softmax, reducing CE gradients to near-zero, cascading to near-zero STE gradients. Z-loss (`1e-4 × logsumexp(logits)².mean()`) keeps logits anchored near zero. From CiprianFlorin (PR #640); origin in PaLM/Gemma. + +**Triton kernel.** The kernel required three sequential fixes (int32 overflow, gc lifetime, traced conditional). Each bug hid behind the previous. + +--- + +## Architecture + +7 SSM-only blocks (no attention), d=576, D_STATE=64. Shared bigram/trigram token-boundary features. + +| Component | Precision | Reason | +|-----------|-----------|--------| +| B_proj, C_proj | Ternary ({-1,0,+1}, STE, per-row least-squares scales at serialization) | State is protected | +| dt_proj | fp16 | Discretization hazard | +| A_log | fp32 | Same | +| D (skip) | fp32 | Standard Mamba direct-path skip | +| Short conv (k=4) | fp32, trained | dt/B/C preprocessing | +| in_proj, out_proj | Ternary | Gating | + +Training: 10% bf16 warmstart → ternary QAT. Muon optimizer, MATRIX_LR=0.40. Z-loss 1e-4. B/C L2 normalization. Triton chunk-wise parallel scan with reversed-scan backward. Batch 32K tokens, EVAL_STRIDE=96. + +Triton scan kernel adapted from fla-org/flash-linear-attention HGRN (MIT license). + +
+Run command + +```bash +TTSM_BLOCKS=7 SSM_ONLY=1 D_STATE=64 A_LOG_INIT=diverse \ +SSM_SHORT_CONV=1 SSM_NORMALIZE_BC=1 TTSM_TRITON=1 \ +EVAL_STRIDE=96 MODEL_DIM=576 NUM_BLOCKS=7 MUON_EQ_R=1 \ +WARMSTART_FRAC=0.1 LAYER_TYPE_WARMDOWN=1 \ +BIGRAM_BUCKETS=3072 TRIGRAM_HASH=1 \ +MATRIX_LR=0.40 MUON_BACKEND_STEPS=5 \ +TRAIN_BATCH_TOKENS=32768 MAX_WALLCLOCK_SECONDS=600 \ +SEED=42 \ +torchrun --standalone --nproc_per_node=8 train_ternary.py +``` + +
+ +## Results + +| Seed | val_bpb | Steps | Artifact | +|------|---------|-------|----------| +| **42** | **2.0032** | 3889 | 12,039,626 B | + +8×H100 SXM, 154 ms/step, 600s wallclock. + +Additional seeds (same config, same hardware, separate runs): seed 314 → 2.0062, seed 999 → 2.0419. 3-seed mean: 2.0171 ± 0.018. + +## Compliance + +- [x] All seeds train in ≤600s +- [x] All artifacts ≤16,000,000 bytes (largest: 12,039,626) +- [x] Sliding window eval, EVAL_STRIDE=96, consistent across seeds +- [x] No test-time training on validation data +- [x] No network calls during evaluation + +--- + +## Attributions + +- **Fork base**: @thwu1 (int5/int6 mixed quantization) +- **Ternary QAT + Z-loss**: @CiprianFlorin (PR #640) +- **Mamba SSM**: Albert Gu, Tri Dao (2023) +- **Triton scan kernel**: fla-org/flash-linear-attention HGRN (MIT license, adapted) +- **BigramHash**: @Raahil Shah +- **MuonEq-R**: @clarkkev (PR #1394) +- **SSM baseline**: @mradassaad (PR #1644) +- **Dynamics protection convergence**: @mradassaad (PR #1890) +- **Muon optimizer**: Kosson et al., Jordan et al. +- **DeepSeek-V4**: CSA m=4 convergence with our short conv k=4; NS iteration calibration informed our NS=5 finding +- **Steps > depth insight**: @newjordan + +--- + +*First ternary SSM. The lane is open.* diff --git a/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/submission.json b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/submission.json new file mode 100644 index 0000000000..2413a05d46 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/submission.json @@ -0,0 +1,44 @@ +{ + "author": "Ambivalence", + "github_id": "dd_dent", + "name": "TTSM: Typical Ternary State-Space Model", + "date": "2026-04-30", + "track": "10min_16mb", + "val_loss": 3.38236311, + "val_bpb": 2.00323032, + "bytes_total": 12039626, + "seeds": [42], + "seed_results": { + "42": { + "val_loss": 3.38236311, + "val_bpb": 2.00323032, + "artifact_bytes": 12039626, + "steps": 3889, + "step_avg_ms": 154.30 + } + }, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "cuda_version": "12.8", + "comparison_baseline_pr": 1644, + "technique_summary": "First ternary SSM submission. Mamba-1 selective SSM with B and C projections quantized to {-1,0,+1} via STE. Hidden state h_t remains fp16 — protected from quantization errors at both write gate (B) and readout selector (C). dt_proj kept fp16 (discretization hazard). Triton chunk-wise parallel scan adapted from fla-org HGRN kernel (MIT). Reversed-scan Triton backward reuses forward kernel on time-reversed inputs (26x speedup, 31s to 1.2s/step). Short conv kernel_size=4 on dt/B/C. B/C L2 normalization. 7 SSM-only blocks, d=576, D_STATE=64. 10% bf16 warmstart then ternary QAT. Muon optimizer MATRIX_LR=0.40, MuonEq-R. Z-loss 1e-4. 11M params (7.8M ternary at 1.6 bits/param, 3.3M fp16/fp32 dynamics). 12 MB artifact.", + "compliance": { + "train_under_600s": true, + "artifact_under_16mb": true, + "no_slot": true, + "no_pre_quant_ttt": true, + "no_etlb": true, + "no_ngram_cache": true + }, + "attribution": { + "mamba_architecture": "Albert Gu, Tri Dao (Mamba: Linear-Time Sequence Modeling, 2023)", + "triton_scan_kernel": "fla-org/flash-linear-attention HGRN kernel (MIT license, adapted)", + "ternary_qat_zloss": "@CiprianFlorin (PR #640) — ternary SOTA + Z-loss technique", + "fork_base": "@thwu1 — int5/int6 mixed quantization", + "muon_eq_r": "@clarkkev (PR #1394)", + "ssm_baseline": "@mradassaad (PR #1644, verified SSM SOTA 1.1473 bpb)", + "dynamics_protection": "@mradassaad (PR #1890)", + "muon_optimizer": "Kosson et al., Jordan et al.", + "steps_depth": "@newjordan" + } +} diff --git a/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ternary_linear.py b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ternary_linear.py new file mode 100644 index 0000000000..090ac84b62 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ternary_linear.py @@ -0,0 +1,128 @@ +""" +TernaryLinear: BitLinear drop-in replacement for CastedLinear. + +Stores weights in fp32 (for optimizer quality), quantizes to ternary {-1, 0, +1} +in the forward pass via absmean scaling + STE (straight-through estimator). + +At inference/serialization time, weights are truly ternary — each weight is one of +three values, packed at 5 trits per byte for 1.6 bits/param density. + +During training, bf16 shadow weights are maintained (same as every other QAT approach). +The density advantage is purely at artifact time. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class TernaryLinear(nn.Module): + """ + Linear layer with ternary quantization-aware training (QAT). + + Forward pass: quantize weight to {-1, 0, +1} * scale via absmean, + then do standard F.linear. Gradient flows through via STE. + + Supports per-row or per-tensor scaling: + - per_row=True: scale_i = mean(|W[i,:]|) for each output row + - per_row=False: scale = mean(|W|) for entire weight matrix + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + per_row: bool = True, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.per_row = per_row + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + # Mark for zero-init detection (same convention as CastedLinear) + self._zero_init = False + + def _quantize_ternary(self, w: Tensor) -> tuple[Tensor, Tensor]: + """Quantize weight to ternary via absmean scaling. + + Returns (w_quantized_scaled, scale) where w_quantized_scaled = ternary * scale. + """ + if self.per_row: + # Per-row scale: each output channel gets its own magnitude + scale = w.abs().mean(dim=1, keepdim=True).clamp_min(1e-8) + else: + # Per-tensor scale + scale = w.abs().mean().clamp_min(1e-8) + # Quantize: round(w / scale) clamped to {-1, 0, +1} + w_ternary = torch.clamp(torch.round(w / scale), -1, 1) + return w_ternary * scale, scale + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + # Quantize to ternary + w_q, _ = self._quantize_ternary(w) + # STE: forward uses quantized weights, backward sees continuous weights + # w + (w_q - w).detach() has the value of w_q but the gradient of w + w_ste = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w_ste, bias) + + def get_ternary_weights(self) -> tuple[Tensor, Tensor]: + """Extract the discrete ternary weights and scales (for serialization). + + Returns: + ternary: int8 tensor of {-1, 0, +1}, shape (out_features, in_features) + scale: fp16 tensor, shape (out_features, 1) if per_row else scalar + """ + with torch.no_grad(): + w = self.weight.float() + if self.per_row: + scale = w.abs().mean(dim=1, keepdim=True).clamp_min(1e-8) + else: + scale = w.abs().mean().clamp_min(1e-8) + ternary = torch.clamp(torch.round(w / scale), -1, 1).to(torch.int8) + return ternary, scale.to(torch.float16) + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"bias={self.bias is not None}, per_row={self.per_row}" + ) + + +def replace_linear_with_ternary( + module: nn.Module, + per_row: bool = True, + skip_patterns: tuple[str, ...] = ("tok_emb", "lm_head", "bigram"), +) -> nn.Module: + """Replace CastedLinear/nn.Linear layers with TernaryLinear, except those + matching skip_patterns (embeddings, head, etc. stay fp16).""" + for name, child in list(module.named_children()): + full_name = name + if any(pat in full_name for pat in skip_patterns): + continue + if isinstance(child, nn.Linear): + ternary = TernaryLinear( + child.in_features, + child.out_features, + bias=child.bias is not None, + per_row=per_row, + ) + # Copy weights + ternary.weight.data.copy_(child.weight.data) + if child.bias is not None and ternary.bias is not None: + ternary.bias.data.copy_(child.bias.data) + # Preserve zero-init flag + ternary._zero_init = getattr(child, "_zero_init", False) + setattr(module, name, ternary) + else: + replace_linear_with_ternary(child, per_row=per_row, skip_patterns=skip_patterns) + return module diff --git a/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ternary_monitor.py b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ternary_monitor.py new file mode 100644 index 0000000000..0ecdd132d9 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ternary_monitor.py @@ -0,0 +1,266 @@ +""" +TernaryMonitor: Adaptive LR controller for ternary QAT training. + +Ported from GeoCog's AdaptiveLRController (geocog/adaptive_lr.py). + +GeoCog watches coherence_H (directional consistency of latent movement) to +detect training states. We adapt this for ternary QAT by watching: + - flip_rate: fraction of weights that changed ternary state per step + - loss_velocity: smoothed slope of training loss + +State mapping: + GeoCog coherence_H → Ternary flip_rate + GROUPTHINK (high coh + low acc) → FROZEN (low flip + stagnant loss) + NOISY (low coh) → CHURNING (high flip + stagnant loss) + LEARNING (healthy coh range) → LEARNING (decreasing loss) + CRYSTALLIZING → PLATEAU (low flip + slow improvement) + CONVERGED → not mapped (we don't have accuracy for LM training) + +The key insight: flip_rate is to ternary weights what coherence_H is to +z-trajectory direction — both measure "is the model making meaningful +progress or stuck?" +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +import torch +from torch import Tensor + + +class TernaryState(Enum): + """Detected training state based on ternary monitoring signals.""" + FROZEN = "frozen" # Weights stuck, loss not improving → spike LR + CHURNING = "churning" # Weights thrashing, loss not improving → reduce LR + LEARNING = "learning" # Loss improving → maintain LR + PLATEAU = "plateau" # Slow/no improvement, moderate flip → nudge LR up + + +@dataclass +class TernaryMonitorConfig: + """Configuration for the ternary monitor.""" + + # Flip rate thresholds + frozen_threshold: float = 0.001 # Below this = frozen + churning_threshold: float = 0.05 # Above this = churning + + # Loss velocity threshold (negative = improving) + improving_threshold: float = -1e-4 # Below this = loss is improving + + # LR adjustment factors (applied to current LR factor) + frozen_factor: float = 2.0 # Spike LR to shake free (geocog: GROUPTHINK) + churning_factor: float = 0.5 # Halve LR to stabilize (geocog: NOISY) + learning_decay: float = 0.9 # Decay factor toward 1.0 when healthy + plateau_factor: float = 1.2 # Gentle nudge up + + # Smoothing + window: int = 50 # Steps to average over + hysteresis_steps: int = 2 # Steps state must be stable before acting + + # LR bounds + min_lr_factor: float = 0.1 # Don't go below base_lr * this + max_lr_factor: float = 5.0 # Don't go above base_lr * this + + # Gradient perturbation on frozen state + perturb_on_frozen: bool = True + perturb_magnitude: float = 0.3 # Noise fraction added to gradients + + # Logging + log_every: int = 100 # Log monitor state every N steps + + +class TernaryMonitor: + """ + Monitors ternary weight dynamics and adjusts learning rate accordingly. + + Usage: + monitor = TernaryMonitor(config, base_lr=0.02) + + for step in training: + loss = train_step() + + # Snapshot ternary weights periodically + if step % monitor_every == 0: + state, lr_factor = monitor.step(model, loss.item()) + # Apply lr_factor to optimizer groups + """ + + def __init__( + self, + config: Optional[TernaryMonitorConfig] = None, + base_lr: float = 0.02, + ): + self.config = config or TernaryMonitorConfig() + self.base_lr = base_lr + + # History buffers + self.loss_history: list[float] = [] + self.flip_rate_history: list[float] = [] + self.state_history: list[TernaryState] = [] + self.lr_factor_history: list[float] = [] + + # Weight snapshots + self._prev_weights: Optional[dict[str, Tensor]] = None + + # State tracking + self.current_state = TernaryState.LEARNING + self.steps_in_state = 0 + self.total_steps = 0 + self.current_lr_factor = 1.0 + + def _extract_ternary_weights(self, model: torch.nn.Module) -> dict[str, Tensor]: + """Extract quantized ternary weights from all TernaryLinear modules.""" + weights = {} + for name, module in model.named_modules(): + if hasattr(module, 'get_ternary_weights'): + ternary, _ = module.get_ternary_weights() + weights[name] = ternary + return weights + + def _compute_flip_rate( + self, prev: dict[str, Tensor], curr: dict[str, Tensor] + ) -> float: + """Compute fraction of weights that changed ternary state.""" + total = 0 + flipped = 0 + for name in prev: + if name in curr: + p, c = prev[name], curr[name] + total += p.numel() + flipped += (p != c).sum().item() + return flipped / max(total, 1) + + def _compute_loss_velocity(self) -> float: + """Compute smoothed loss velocity (negative = improving).""" + w = self.config.window + if len(self.loss_history) < w: + return -1.0 # Assume learning early on + return (self.loss_history[-1] - self.loss_history[-w]) / w + + def _detect_state(self, flip_rate: float, loss_vel: float) -> TernaryState: + """Detect training state from signals.""" + cfg = self.config + + # Average flip rate for stability + recent_flips = self.flip_rate_history[-cfg.window:] + avg_flip = sum(recent_flips) / max(len(recent_flips), 1) + + is_improving = loss_vel < cfg.improving_threshold + + if avg_flip < cfg.frozen_threshold and not is_improving: + return TernaryState.FROZEN + elif avg_flip > cfg.churning_threshold and not is_improving: + return TernaryState.CHURNING + elif is_improving: + return TernaryState.LEARNING + else: + return TernaryState.PLATEAU + + def step( + self, + model: torch.nn.Module, + loss_value: float, + ) -> tuple[TernaryState, float]: + """ + Update monitor with current model state and loss. + + Args: + model: the model (must contain TernaryLinear modules) + loss_value: current training loss + + Returns: + state: detected training state + lr_factor: multiplier to apply to learning rate + """ + # Extract current ternary weights + curr_weights = self._extract_ternary_weights(model) + + # Compute flip rate + if self._prev_weights is not None: + flip_rate = self._compute_flip_rate(self._prev_weights, curr_weights) + else: + flip_rate = 1.0 # First step: everything is "new" + self._prev_weights = {k: v.clone() for k, v in curr_weights.items()} + + # Update histories + self.loss_history.append(loss_value) + self.flip_rate_history.append(flip_rate) + + # Compute loss velocity + loss_vel = self._compute_loss_velocity() + + # Detect state + new_state = self._detect_state(flip_rate, loss_vel) + + # Track state transitions (hysteresis) + if new_state != self.current_state: + self.steps_in_state = 0 + else: + self.steps_in_state += 1 + self.current_state = new_state + + # Compute LR factor + cfg = self.config + if self.steps_in_state < cfg.hysteresis_steps: + # Not stable yet — hold current factor + lr_factor = self.current_lr_factor + elif new_state == TernaryState.LEARNING: + # Decay factor toward 1.0: factor = 1.0 + (factor - 1.0) * decay + # After a FROZEN spike (factor=2.0), this gradually returns to baseline + lr_factor = 1.0 + (self.current_lr_factor - 1.0) * cfg.learning_decay + else: + factor_map = { + TernaryState.FROZEN: cfg.frozen_factor, + TernaryState.CHURNING: cfg.churning_factor, + TernaryState.PLATEAU: cfg.plateau_factor, + } + raw_factor = self.current_lr_factor * factor_map[new_state] + lr_factor = max(cfg.min_lr_factor, min(cfg.max_lr_factor, raw_factor)) + + self.current_lr_factor = lr_factor + self.state_history.append(new_state) + self.lr_factor_history.append(lr_factor) + self.total_steps += 1 + + return new_state, lr_factor + + def should_perturb(self) -> bool: + """Check if we should add noise to gradients (frozen escape).""" + return ( + self.config.perturb_on_frozen + and self.current_state == TernaryState.FROZEN + and self.steps_in_state >= self.config.hysteresis_steps + ) + + def get_summary(self) -> dict: + """Get summary statistics for logging.""" + state_counts: dict[str, int] = {} + for s in self.state_history: + state_counts[s.value] = state_counts.get(s.value, 0) + 1 + + return { + "total_steps": self.total_steps, + "final_state": self.current_state.value, + "final_lr_factor": self.current_lr_factor, + "state_counts": state_counts, + "avg_flip_rate": ( + sum(self.flip_rate_history) / max(len(self.flip_rate_history), 1) + ), + "final_loss_velocity": self._compute_loss_velocity(), + "frozen_episodes": state_counts.get("frozen", 0), + "churning_episodes": state_counts.get("churning", 0), + } + + def format_status(self) -> str: + """Format current status for logging.""" + flip = self.flip_rate_history[-1] if self.flip_rate_history else 0.0 + loss_vel = self._compute_loss_velocity() + return ( + f"monitor:{self.current_state.value} " + f"flip_rate:{flip:.6f} loss_vel:{loss_vel:.6f} " + f"lr_factor:{self.current_lr_factor:.3f} " + f"steps_in_state:{self.steps_in_state}" + ) diff --git a/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/train_seed42.log b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/train_seed42.log new file mode 100644 index 0000000000..3f13a0277c --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/train_seed42.log @@ -0,0 +1,210 @@ +W0429 16:02:03.643000 154 torch/distributed/run.py:803] +W0429 16:02:03.643000 154 torch/distributed/run.py:803] ***************************************** +W0429 16:02:03.643000 154 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0429 16:02:03.643000 154 torch/distributed/run.py:803] ***************************************** +logs/e735e273-9394-492b-b5fd-2c3c2d66b57a.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model:TTSM num_blocks:7 num_loops:1 effective_depth:7 model_dim:576 +ttsm:blocks=7/7 kind=SSM-only d_state=64 a_log_init=diverse scan_chunk_size=64 short_conv=True normalize_bc=True +model_params:11035714 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.4 scalar_lr:0.02 muon_eq_r:True +train_batch_tokens:32768 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +ternary:enabled=True per_row=True warmstart=0.1 monitor=log learned_quant_mix=False qm_penalty=0.001 layer_type_warmdown=True +ternary_monitor:initialized mode=log +ternary:warmstart bf16 for first 10% of training +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9573 train_time:160ms step_avg:159.66ms +step:2/20000 train_loss:13.8095 train_time:277ms step_avg:138.48ms +step:3/20000 train_loss:7.7085 train_time:433ms step_avg:144.42ms +step:4/20000 train_loss:6.8335 train_time:589ms step_avg:147.32ms +step:5/20000 train_loss:6.3488 train_time:745ms step_avg:149.02ms +step:6/20000 train_loss:6.1646 train_time:902ms step_avg:150.41ms +step:7/20000 train_loss:5.9004 train_time:1058ms step_avg:151.20ms +step:8/20000 train_loss:5.4311 train_time:1214ms step_avg:151.78ms +step:9/20000 train_loss:5.5429 train_time:1373ms step_avg:152.50ms +step:10/20000 train_loss:5.4773 train_time:1529ms step_avg:152.87ms +step:100/20000 train_loss:3.9029 train_time:15223ms step_avg:152.23ms +step:200/20000 train_loss:3.6360 train_time:30450ms step_avg:152.25ms +step:300/20000 train_loss:3.4224 train_time:45669ms step_avg:152.23ms +ternary:activated layer_type_warmdown ramp_steps:354 step:394 +step:400/20000 train_loss:3.7694 train_time:60888ms step_avg:152.22ms +monitor:learning flip_rate:1.000000 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:1 +step:500/20000 train_loss:3.4860 train_time:76115ms step_avg:152.23ms +monitor:learning flip_rate:0.251582 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:2 +checkpoint:saved step:500 path:/workspace/ckpt_s000500_seed42.pt +checkpoint:ptz_saved step:500 path:/workspace/ckpt_s000500_seed42.ptz size:12063980 +step:600/20000 train_loss:3.4718 train_time:92288ms step_avg:153.81ms +monitor:learning flip_rate:0.236480 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:3 +step:700/20000 train_loss:3.5342 train_time:107518ms step_avg:153.60ms +monitor:learning flip_rate:0.229603 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:4 +step:800/20000 train_loss:3.7680 train_time:122768ms step_avg:153.46ms +monitor:learning flip_rate:0.228890 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:5 +step:900/20000 train_loss:3.5070 train_time:137992ms step_avg:153.32ms +monitor:learning flip_rate:0.216324 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:6 +step:1000/20000 train_loss:3.6167 train_time:153224ms step_avg:153.22ms +monitor:learning flip_rate:0.210186 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:7 +checkpoint:saved step:1000 path:/workspace/ckpt_s001000_seed42.pt +checkpoint:ptz_saved step:1000 path:/workspace/ckpt_s001000_seed42.ptz size:12056058 +step:1100/20000 train_loss:3.7234 train_time:169378ms step_avg:153.98ms +monitor:learning flip_rate:0.194191 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:8 +layer_type_warmdown:complete at step:1102 +step:1200/20000 train_loss:3.6442 train_time:184602ms step_avg:153.83ms +monitor:learning flip_rate:0.182293 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:9 +step:1300/20000 train_loss:3.5068 train_time:199813ms step_avg:153.70ms +monitor:learning flip_rate:0.175489 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:10 +step:1400/20000 train_loss:3.8614 train_time:215027ms step_avg:153.59ms +monitor:learning flip_rate:0.164171 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:11 +step:1500/20000 train_loss:3.9501 train_time:230237ms step_avg:153.49ms +monitor:learning flip_rate:0.161048 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:12 +checkpoint:saved step:1500 path:/workspace/ckpt_s001500_seed42.pt +checkpoint:ptz_saved step:1500 path:/workspace/ckpt_s001500_seed42.ptz size:12047082 +step:1600/20000 train_loss:3.4522 train_time:246363ms step_avg:153.98ms +monitor:learning flip_rate:0.146845 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:13 +step:1700/20000 train_loss:3.6442 train_time:261578ms step_avg:153.87ms +monitor:learning flip_rate:0.133849 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:14 +step:1800/20000 train_loss:3.4947 train_time:276788ms step_avg:153.77ms +monitor:learning flip_rate:0.119821 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:15 +step:1900/20000 train_loss:3.7969 train_time:292034ms step_avg:153.70ms +monitor:learning flip_rate:0.109668 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:16 +step:2000/20000 train_loss:3.8891 train_time:307259ms step_avg:153.63ms +monitor:learning flip_rate:0.102029 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:17 +checkpoint:saved step:2000 path:/workspace/ckpt_s002000_seed42.pt +checkpoint:ptz_saved step:2000 path:/workspace/ckpt_s002000_seed42.ptz size:12044486 +step:2100/20000 train_loss:3.7660 train_time:323355ms step_avg:153.98ms +monitor:learning flip_rate:0.092557 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:18 +step:2200/20000 train_loss:3.6453 train_time:338567ms step_avg:153.89ms +monitor:learning flip_rate:0.085794 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:19 +step:2300/20000 train_loss:3.5297 train_time:353813ms step_avg:153.83ms +monitor:learning flip_rate:0.079167 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:20 +step:2400/20000 train_loss:3.6347 train_time:369025ms step_avg:153.76ms +monitor:learning flip_rate:0.072580 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:21 +step:2500/20000 train_loss:3.5448 train_time:384235ms step_avg:153.69ms +monitor:learning flip_rate:0.066831 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:22 +checkpoint:saved step:2500 path:/workspace/ckpt_s002500_seed42.pt +checkpoint:ptz_saved step:2500 path:/workspace/ckpt_s002500_seed42.ptz size:12042034 +step:2600/20000 train_loss:3.3144 train_time:400454ms step_avg:154.02ms +monitor:learning flip_rate:0.061444 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:23 +step:2700/20000 train_loss:3.7056 train_time:415680ms step_avg:153.96ms +monitor:learning flip_rate:0.056307 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:24 +step:2800/20000 train_loss:3.7497 train_time:430892ms step_avg:153.89ms +monitor:learning flip_rate:0.051850 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:25 +step:2900/20000 train_loss:3.4319 train_time:446217ms step_avg:153.87ms +monitor:learning flip_rate:0.046285 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:26 +step:3000/20000 train_loss:3.4664 train_time:461427ms step_avg:153.81ms +monitor:learning flip_rate:0.042652 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:27 +checkpoint:saved step:3000 path:/workspace/ckpt_s003000_seed42.pt +checkpoint:ptz_saved step:3000 path:/workspace/ckpt_s003000_seed42.ptz size:12041042 +step:3100/20000 train_loss:3.6180 train_time:478922ms step_avg:154.49ms +monitor:learning flip_rate:0.037634 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:28 +step:3200/20000 train_loss:3.2874 train_time:494141ms step_avg:154.42ms +monitor:learning flip_rate:0.032860 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:29 +step:3300/20000 train_loss:3.5751 train_time:509357ms step_avg:154.35ms +monitor:learning flip_rate:0.028174 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:30 +step:3400/20000 train_loss:3.3550 train_time:524572ms step_avg:154.29ms +monitor:learning flip_rate:0.024406 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:31 +step:3500/20000 train_loss:3.8713 train_time:539793ms step_avg:154.23ms +monitor:learning flip_rate:0.020268 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:32 +checkpoint:saved step:3500 path:/workspace/ckpt_s003500_seed42.pt +checkpoint:ptz_saved step:3500 path:/workspace/ckpt_s003500_seed42.ptz size:12040144 +step:3600/20000 train_loss:3.3901 train_time:556057ms step_avg:154.46ms +monitor:learning flip_rate:0.015384 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:33 +step:3700/20000 train_loss:3.5309 train_time:571273ms step_avg:154.40ms +monitor:learning flip_rate:0.011116 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:34 +step:3800/20000 train_loss:3.3641 train_time:586482ms step_avg:154.34ms +monitor:learning flip_rate:0.006705 loss_vel:-1.000000 lr_factor:1.000 steps_in_state:35 +step:3889/20000 val_loss:3.3967 val_bpb:2.0117 train_time:600089ms step_avg:154.30ms +stopping_early: wallclock_cap train_time:600089ms step:3889/20000 +peak memory allocated: 38095 MiB reserved: 39522 MiB +ternary_monitor_summary: {'total_steps': 35, 'final_state': 'learning', 'final_lr_factor': 1.0, 'state_counts': {'learning': 35}, 'avg_flip_rate': 0.13412814097442138, 'final_loss_velocity': -1.0, 'frozen_episodes': 0, 'churning_episodes': 0} +ternary_distribution: total:7,483,392 -1:33.4% 0:33.1% +1:33.4% +ternary_serialize: ternary_params:7763968 passthrough_params:3271746 +Serialized model ternary+zlib: 12039626 bytes +Code size: 100880 bytes +Total submission size: 12140506 bytes +final_eval_mode:sliding_window stride:96 batch_seqs:32 + sliding_eval [ 0.0%] 32/80757 windows running_bpb=1.925180 + sliding_eval [ 2.0%] 1632/80757 windows running_bpb=2.021044 + sliding_eval [ 4.0%] 3232/80757 windows running_bpb=1.985360 + sliding_eval [ 6.0%] 4832/80757 windows running_bpb=2.009071 + sliding_eval [ 8.0%] 6432/80757 windows running_bpb=2.014345 + sliding_eval [ 9.9%] 8032/80757 windows running_bpb=2.009534 + sliding_eval [ 11.9%] 9632/80757 windows running_bpb=2.007353 + sliding_eval [ 13.9%] 11232/80757 windows running_bpb=2.012014 + sliding_eval [ 15.9%] 12832/80757 windows running_bpb=2.011026 + sliding_eval [ 17.9%] 14432/80757 windows running_bpb=2.012044 + sliding_eval [ 19.9%] 16032/80757 windows running_bpb=2.015362 + sliding_eval [ 21.8%] 17632/80757 windows running_bpb=2.017320 + sliding_eval [ 23.8%] 19232/80757 windows running_bpb=2.021735 + sliding_eval [ 25.8%] 20832/80757 windows running_bpb=2.020565 + sliding_eval [ 27.8%] 22432/80757 windows running_bpb=2.024036 + sliding_eval [ 29.8%] 24032/80757 windows running_bpb=2.023923 + sliding_eval [ 31.7%] 25632/80757 windows running_bpb=2.024551 + sliding_eval [ 33.7%] 27232/80757 windows running_bpb=2.030427 + sliding_eval [ 35.7%] 28832/80757 windows running_bpb=2.031924 + sliding_eval [ 37.7%] 30432/80757 windows running_bpb=2.032703 + sliding_eval [ 39.7%] 32032/80757 windows running_bpb=2.031216 + sliding_eval [ 41.6%] 33632/80757 windows running_bpb=2.032382 + sliding_eval [ 43.6%] 35232/80757 windows running_bpb=2.033270 + sliding_eval [ 45.6%] 36832/80757 windows running_bpb=2.033172 + sliding_eval [ 47.6%] 38432/80757 windows running_bpb=2.034914 + sliding_eval [ 49.6%] 40032/80757 windows running_bpb=2.034014 + sliding_eval [ 51.6%] 41632/80757 windows running_bpb=2.033366 + sliding_eval [ 53.5%] 43232/80757 windows running_bpb=2.033146 + sliding_eval [ 55.5%] 44832/80757 windows running_bpb=2.031532 + sliding_eval [ 57.5%] 46432/80757 windows running_bpb=2.030514 + sliding_eval [ 59.5%] 48032/80757 windows running_bpb=2.029468 + sliding_eval [ 61.5%] 49632/80757 windows running_bpb=2.029884 + sliding_eval [ 63.4%] 51232/80757 windows running_bpb=2.029239 + sliding_eval [ 65.4%] 52832/80757 windows running_bpb=2.029305 + sliding_eval [ 67.4%] 54432/80757 windows running_bpb=2.030022 + sliding_eval [ 69.4%] 56032/80757 windows running_bpb=2.032949 + sliding_eval [ 71.4%] 57632/80757 windows running_bpb=2.031508 + sliding_eval [ 73.3%] 59232/80757 windows running_bpb=2.032410 + sliding_eval [ 75.3%] 60832/80757 windows running_bpb=2.032854 + sliding_eval [ 77.3%] 62432/80757 windows running_bpb=2.033547 + sliding_eval [ 79.3%] 64032/80757 windows running_bpb=2.034310 + sliding_eval [ 81.3%] 65632/80757 windows running_bpb=2.037585 + sliding_eval [ 83.3%] 67232/80757 windows running_bpb=2.037453 + sliding_eval [ 85.2%] 68832/80757 windows running_bpb=2.037533 + sliding_eval [ 87.2%] 70432/80757 windows running_bpb=2.036644 + sliding_eval [ 89.2%] 72032/80757 windows running_bpb=2.036905 + sliding_eval [ 91.2%] 73632/80757 windows running_bpb=2.037254 + sliding_eval [ 93.2%] 75232/80757 windows running_bpb=2.037504 + sliding_eval [ 95.1%] 76832/80757 windows running_bpb=2.037974 + sliding_eval [ 97.1%] 78432/80757 windows running_bpb=2.038270 + sliding_eval [ 99.1%] 80032/80757 windows running_bpb=2.039144 +final_ternary_roundtrip val_loss:3.3824 val_bpb:2.0032 eval_time:1823473ms +final_ternary_roundtrip_exact val_loss:3.38236311 val_bpb:2.00323032 +[W429 16:44:38.398182101 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[W429 16:44:38.500840353 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[W429 16:44:38.731560272 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[W429 16:44:38.862484357 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[W429 16:44:38.911345223 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[W429 16:44:38.943544393 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[W429 16:44:38.030114791 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[W429 16:44:38.033248966 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) +[W429 16:44:41.139661567 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) diff --git a/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/train_ternary.py b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/train_ternary.py new file mode 100644 index 0000000000..2fdf30a8ea --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/train_ternary.py @@ -0,0 +1,2154 @@ +""" +Ternary Parameter Golf: train_ternary.py + +Ternary-weight ({-1, 0, +1}) language model training with monitoring-guided +adaptive LR control. Forked from SOTA 10L_Int5MLP submission. + +Key changes from baseline: +- TernaryLinear replaces CastedLinear for body weights (STE-based QAT) +- Trit packing serialization (5 trits/byte = 1.6 bits/param) +- Optional bf16 warmstart → ternary QAT transition +- 7 SSM blocks (density advantage funds more depth) +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Optional Triton kernel for TTSM scan (replaces Python outer loop, ~100x faster) +# Gate: TTSM_TRITON=1. Requires CUDA + Triton (available on H100 pods). +_TTSM_TRITON_FWD = None +if bool(int(os.environ.get("TTSM_TRITON", "0"))): + try: + # Insert the script's own directory so ttsm_triton_scan.py can be found + # regardless of how the script was invoked (torchrun, python, etc.) + _script_dir = str(Path(__file__).parent) + if _script_dir not in sys.path: + sys.path.insert(0, _script_dir) + from ttsm_triton_scan import selective_ssm_triton_forward as _TTSM_TRITON_FWD + except (ImportError, ModuleNotFoundError) as _e: + print(f"Warning: TTSM_TRITON=1 but import failed ({_e}). Falling back to Python scan.") + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + checkpoint_every = int(os.environ.get("CHECKPOINT_EVERY", 0)) # 0=off; N=save raw state dict every N steps + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + profile_steps = int(os.environ.get("PROFILE_STEPS", 0)) # 0=off; N=record N active steps after warmup + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_blocks = int(os.environ.get("NUM_BLOCKS", 5)) + num_loops = int(os.environ.get("NUM_LOOPS", 1)) + # num_layers kept for logging; effective depth = num_blocks * num_loops + num_layers = int(os.environ.get("NUM_LAYERS", 14)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_eq_r = bool(int(os.environ.get("MUON_EQ_R", "0"))) # MuonEq-R: row-normalize before NS (arXiv:2603.28254) + # Selective Muon for B/C projections: give SSM write-gate (B) and readout (C) a separate LR. + # 0.0 = disabled (all matrix params in one Muon group, current default). + # BC_LR > 0 = split B_proj/C_proj into their own Muon group at this LR. + # Tink probe: k-axis asymmetry — B (write gate) and C (readout) may benefit from different learning dynamics. + bc_lr = float(os.environ.get("BC_LR", 0.0)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.0)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # BigramHash: BIGRAM_BUCKETS (alias BIGRAM_VOCAB_SIZE) buckets, BIGRAM_D_HEAD (alias BIGRAM_DIM) d_head + # Recommended: 3072 buckets, d_head=16 (session 8 sweep winner, +0.004 bpb local) + bigram_vocab_size = int(os.environ.get("BIGRAM_BUCKETS", os.environ.get("BIGRAM_VOCAB_SIZE", 0))) + bigram_dim = int(os.environ.get("BIGRAM_D_HEAD", os.environ.get("BIGRAM_DIM", 16))) + # TrigramHash: 2-head (bigram+trigram) embedding. Gate: TRIGRAM_HASH=1. + # Recommended: 8192 buckets, d_head=16 (+0.016 bpb local, +3.1% compute, +1.05MB) + trigram_hash = bool(int(os.environ.get("TRIGRAM_HASH", "0"))) + trigram_buckets = int(os.environ.get("TRIGRAM_BUCKETS", 8192)) + trigram_d_head = int(os.environ.get("TRIGRAM_D_HEAD", 16)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # EMA (replaces SWA when enabled) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.995)) + + # Ternary QAT settings + ternary_enabled = bool(int(os.environ.get("TERNARY_ENABLED", "1"))) + ternary_per_row = bool(int(os.environ.get("TERNARY_PER_ROW", "1"))) + # Warmstart fraction: train in bf16 for this fraction before switching to ternary QAT + # 0.0 = immediate ternary, 0.1 = 10% bf16 warmstart + warmstart_frac = float(os.environ.get("WARMSTART_FRAC", 0.0)) + # Monitor settings + # "log" = log signals only, "adapt" = adaptive LR, "off" = disabled + monitor_mode = os.environ.get("MONITOR_MODE", "log") + monitor_every = int(os.environ.get("MONITOR_EVERY", 100)) + # Committee density: per-block ternary non-zero fraction. 0 = disabled. + committee_log_interval = int(os.environ.get("COMMITTEE_LOG_INTERVAL", 0)) + # Per-block warmdown: late blocks (last 2) crystallize at 0.5x ramp rate. + # Tink: -0.013 bpb local (weak, correct direction). Gate: PER_BLOCK_WARMDOWN=1. + per_block_warmdown = bool(int(os.environ.get("PER_BLOCK_WARMDOWN", "0"))) + # Learned quant_mix: make quant_mix a learned scalar per TernaryLinear (vs fixed schedule). + # Tink session 8: -0.031 bpb local with penalty=0.001. Gate: LEARNED_QUANT_MIX=1. + learned_quant_mix = bool(int(os.environ.get("LEARNED_QUANT_MIX", "0"))) + qm_penalty = float(os.environ.get("QM_PENALTY", "0.001")) + # Layer-type warmdown: attn layers crystallize at 1x, MLP layers at 0.5x ramp rate. + # Tink session 8: -0.022 bpb + variance reduction. Gate: LAYER_TYPE_WARMDOWN=1. + layer_type_warmdown = bool(int(os.environ.get("LAYER_TYPE_WARMDOWN", "0"))) + + # ── TTSM: Selective SSM blocks ────────────────────────────────────────────── + # TTSM_BLOCKS=N replaces the first N transformer blocks with TTSMBlocks. + # 0 = pure transformer (default). N = num_blocks = pure TTSM. + # H100 target config: TTSM_BLOCKS=7, SSM_ONLY=1, D_STATE=64, A_LOG_INIT=diverse + ttsm_blocks = int(os.environ.get("TTSM_BLOCKS", "0")) + d_state = int(os.environ.get("D_STATE", "64")) + ssm_only = bool(int(os.environ.get("SSM_ONLY", "0"))) # drop MLP from SSM blocks + a_log_init = os.environ.get("A_LOG_INIT", "diverse") # "zero" or "diverse" + scan_chunk_size = int(os.environ.get("SCAN_CHUNK_SIZE", "64")) # inner scan loop chunk + # Short conv: causal depthwise conv (kernel=4) on dt/B/C before recurrence. + # GDN study (session 16): "crucial to performance". Enabled by SSM_SHORT_CONV=1. + ssm_short_conv = bool(int(os.environ.get("SSM_SHORT_CONV", "0"))) + # L2 normalize B and C before state update (stabilizes ternary QAT). + ssm_normalize_bc = bool(int(os.environ.get("SSM_NORMALIZE_BC", "0"))) + # Triton kernel for SSM scan: replaces Python outer loop with Triton chunk scan. + # ~100x speedup over Python loop (from 9s/step → ~100ms/step target). + # Requires CUDA + Triton. Enable with TTSM_TRITON=1. + ttsm_triton = bool(int(os.environ.get("TTSM_TRITON", "0"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Two-stage polynomial (DeepSeek V4 style): + # Phase 1 (convergence): (3.4445, -4.7750, 2.0315) — first min(steps, 8) iterations + # Phase 2 (refinement): (2.0, -1.5, 0.5) — remaining steps after 8 + # When steps <= 8: all steps use phase-1 (backward compatible with prior runs). + # When steps == 10: 8 phase-1 + 2 phase-2 = DeepSeek V4 hybrid NS. + # Tink validation: -0.028 loss at 150 steps vs steps=5, zero step-time overhead. + # + # Non-2D tensors (e.g. depthwise conv weights shape (d, 1, k)): reshape to 2D for NS. + # Muon packs updates into a flat buffer and uses p.numel(), so returning 2D is fine. + assert G.ndim >= 2, f"NS requires at least 2D, got shape {G.shape}" + if G.ndim > 2: + G = G.reshape(G.shape[0], -1) + a1, b1, c1 = (3.4445, -4.7750, 2.0315) + a2, b2, c2 = (2.0, -1.5, 0.5) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(min(steps, 8)): + A = X @ X.T + B = b1 * A + c1 * A @ A + X = a1 * X + B @ X + for _ in range(max(steps - 8, 0)): + A = X @ X.T + B = b2 * A + c2 * A @ A + X = a2 * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0, muon_eq_r: bool = False): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay, muon_eq_r=muon_eq_r), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + muon_eq_r = group.get("muon_eq_r", False) + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if muon_eq_r and g.dim() == 2: + # MuonEq-R (arXiv:2603.28254): row-normalize before NS + row_norms = g.norm(dim=1, keepdim=True).clamp(min=1e-8) + g = g / row_norms + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING SERIALIZATION (TERNARY TRIT PACKING) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + # ssm_scale, mlp_scale, resid_mix: per-block scale/mix parameters + # dt_proj: fp16 CastedLinear in SSM — must NOT go to Muon or ternary serialize + # A_log: fp32 2-D SSM state matrix — must NOT go to Muon or ternary serialize + "attn_scale,attn_scales,mlp_scale,mlp_scales,ssm_scale,resid_mix,resid_mixes," + "q_gain,skip_weight,skip_weights,loop_pos,smear,bigram.scale,dt_proj,A_log", + ).split(",") + if pattern +) + +_TRITS_PER_BYTE = 5 +_POWERS = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int32) + +def pack_trits(tensor: Tensor) -> tuple[Tensor, list[int], int]: + """Pack ternary {-1, 0, +1} tensor into bytes. 5 trits per byte.""" + shape = list(tensor.shape) + flat = (tensor.reshape(-1).to(torch.int32) + 1) + n = flat.numel() + pad = (_TRITS_PER_BYTE - n % _TRITS_PER_BYTE) % _TRITS_PER_BYTE + if pad > 0: + flat = torch.cat([flat, torch.ones(pad, dtype=torch.int32)]) + groups = flat.reshape(-1, _TRITS_PER_BYTE) + packed = (groups * _POWERS).sum(dim=1).to(torch.uint8) + return packed, shape, pad + +def unpack_trits(packed: Tensor, shape: list[int], pad: int) -> Tensor: + """Unpack bytes back to ternary {-1, 0, +1} tensor.""" + values = packed.to(torch.int32).unsqueeze(1).expand(-1, _TRITS_PER_BYTE) + divisors = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int32, device=packed.device) + trits = ((values // divisors) % 3).reshape(-1) + total = 1 + for s in shape: + total *= s + trits = trits[:total] + return (trits - 1).to(torch.int8).reshape(shape) + +def _is_ternary_param(name: str, tensor: Tensor) -> bool: + """Determine if a parameter should be serialized as ternary.""" + if tensor.numel() <= 8192: + return False + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return False + if any(p in name for p in ("tok_emb", "lm_head", "bigram")): + return False + if tensor.ndim != 2: + return False + return True + +def ternary_serialize(state_dict: dict[str, Tensor], per_row: bool = True) -> dict: + """Serialize model with ternary weights packed as trits.""" + obj: dict = { + "__format__": "ternary_packed_v1", + "ternary": {}, "scales": {}, "metadata": {}, "passthrough": {}, + } + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if _is_ternary_param(name, t): + t_float = t.float() + # First pass: absmean scale to determine ternary assignments + if per_row and t_float.ndim == 2: + scale_init = t_float.abs().mean(dim=1, keepdim=True).clamp_min(1e-8) + else: + scale_init = t_float.abs().mean().clamp_min(1e-8) + ternary = torch.clamp(torch.round(t_float / scale_init), -1, 1).to(torch.int8) + # Second pass: least-squares optimal scale per row + # s_i = (W_i · T_i) / (T_i · T_i), minimizes ||W - s*T||^2 + if per_row and t_float.ndim == 2: + t_f = ternary.float() + num = (t_float * t_f).sum(dim=1, keepdim=True) + den = (t_f * t_f).sum(dim=1, keepdim=True).clamp(min=1.0) + scale = (num / den).clamp_min(1e-8) + else: + scale = scale_init # fallback for per-tensor + packed, shape, pad = pack_trits(ternary) + obj["ternary"][name] = packed + obj["scales"][name] = scale.to(torch.float16) + obj["metadata"][name] = {"shape": shape, "pad": pad} + elif any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + obj["passthrough"][name] = t.float() + else: + obj["passthrough"][name] = t.to(torch.float16) if t.is_floating_point() else t + return obj + +def ternary_deserialize(obj: dict, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Deserialize packed ternary model back to a state dict.""" + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + if name in obj["ternary"]: + packed = obj["ternary"][name] + meta = obj["metadata"][name] + ternary = unpack_trits(packed, meta["shape"], meta["pad"]) + scale = obj["scales"][name].float() + if scale.ndim > 1: + out[name] = (ternary.float() * scale).to(orig.dtype) + else: + out[name] = (ternary.float() * scale.item()).to(orig.dtype) + elif name in obj["passthrough"]: + t = obj["passthrough"][name] + out[name] = t.to(orig.dtype) if t.is_floating_point() and t.dtype != orig.dtype else t + else: + raise KeyError(f"Parameter '{name}' not found in serialized model") + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class TernaryLinear(nn.Module): + """Linear layer with ternary QAT. STE-based: forward uses {-1,0,+1}*scale, + backward sees continuous weights.""" + + def __init__(self, in_features: int, out_features: int, bias: bool = False, + per_row: bool = True, learned_qm: bool = False): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.per_row = per_row + self.learned_qm = learned_qm + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self._zero_init = False + if learned_qm: + # Scalar learned mix: sigmoid(-3.0) ≈ 0.05 — model starts nearly float + self.qm_logit = nn.Parameter(torch.tensor(-3.0)) + # Branchless quantization toggle: 0.0 = passthrough (bf16 warmstart), 1.0 = ternary STE + # With learned_qm: effective mix = _quant_mix * sigmoid(qm_logit) + # Using a buffer so torch.compile sees a tensor, not a python bool + self.register_buffer("_quant_mix", torch.tensor(1.0), persistent=False) + + @property + def quantize_enabled(self) -> bool: + return self._quant_mix.item() > 0.5 + + @quantize_enabled.setter + def quantize_enabled(self, val: bool) -> None: + self._quant_mix.fill_(1.0 if val else 0.0) + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + # Always compute ternary quantization (branchless for torch.compile) + if self.per_row: + scale = w.abs().mean(dim=1, keepdim=True).clamp_min(1e-8) + else: + scale = w.abs().mean().clamp_min(1e-8) + w_q = torch.clamp(torch.round(w / scale), -1, 1) * scale + # STE: blend between passthrough (mix=0) and quantized (mix=1) + diff = (w_q - w).detach() + if self.learned_qm: + effective_qm = self._quant_mix * torch.sigmoid(self.qm_logit) + else: + effective_qm = self._quant_mix + w = w + diff * effective_qm + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + def get_ternary_weights(self) -> tuple[Tensor, Tensor]: + """Extract discrete ternary weights and scales for serialization.""" + with torch.no_grad(): + w = self.weight.float() + if self.per_row: + scale = w.abs().mean(dim=1, keepdim=True).clamp_min(1e-8) + else: + scale = w.abs().mean().clamp_min(1e-8) + ternary = torch.clamp(torch.round(w / scale), -1, 1).to(torch.int8) + return ternary, scale.to(torch.float16) + + +def compute_committee_density(model: nn.Module) -> str: + """Log per-block ternary non-zero fraction (attn / mlp_fc / mlp_proj) + row_cv. + + Works on both TTRM and GPT models. The 'committee' is the set of active + (non-zero) ternary weights — sparser = more selective = better LR signal. + + row_cv (coefficient of variation of per-row non-zero fraction) is the + lock-formation metric from the SSD fork/lock framing: + - Low row_cv (~1.0) = forks open, rows uniformly active (early blocks) + - High row_cv (~1.5+) = locks formed, bimodal specialist/silent rows (late blocks) + Computed for mlp_fc only (the primary lottery-ticket layer). + """ + if not hasattr(model, "blocks"): + return "committee_density:no_blocks" + parts = [] + for i, block in enumerate(model.blocks): + attn_nz: list[float] = [] + for attr in ("c_q", "c_k", "c_v", "proj"): + layer = getattr(block.attn, attr, None) + if isinstance(layer, TernaryLinear): + t, _ = layer.get_ternary_weights() + attn_nz.append((t != 0).float().mean().item()) + attn_frac = sum(attn_nz) / len(attn_nz) if attn_nz else 0.0 + + fc_frac = 0.0 + proj_frac = 0.0 + fc_row_cv = 0.0 + if hasattr(block, "mlp"): + if isinstance(block.mlp.fc, TernaryLinear): + t, _ = block.mlp.fc.get_ternary_weights() + nz = (t != 0).float() + fc_frac = nz.mean().item() + # row_cv: lock-formation metric (high = specialist rows crystallized) + row_nz = nz.mean(dim=1) # per output-neuron non-zero fraction + row_mean = row_nz.mean().item() + row_std = row_nz.std().item() + fc_row_cv = row_std / max(row_mean, 1e-8) + if isinstance(block.mlp.proj, TernaryLinear): + t, _ = block.mlp.proj.get_ternary_weights() + proj_frac = (t != 0).float().mean().item() + parts.append( + f"b{i}:attn={attn_frac:.3f},fc={fc_frac:.3f},proj={proj_frac:.3f},fc_row_cv={fc_row_cv:.3f}" + ) + return "committee_density " + " ".join(parts) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, + learned_qm: bool = False): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = TernaryLinear(dim, dim, bias=False, learned_qm=learned_qm) + self.c_k = TernaryLinear(dim, kv_dim, bias=False, learned_qm=learned_qm) + self.c_v = TernaryLinear(dim, kv_dim, bias=False, learned_qm=learned_qm) + self.proj = TernaryLinear(dim, dim, bias=False, learned_qm=learned_qm) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float, learned_qm: bool = False): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = TernaryLinear(dim, hidden, bias=False, learned_qm=learned_qm) + self.proj = TernaryLinear(hidden, dim, bias=False, learned_qm=learned_qm) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class TrigramHashEmbedding(nn.Module): + """EngramLite: lookup-table embeddings for token bigrams + trigrams. + + head 0: bigram = XOR_hash(prev, curr) + head 1: trigram = XOR_hash(prev2, prev, curr) + + Zero-init proj + learnable scale → neutral at init, learns gradually. + """ + def __init__(self, num_buckets: int, d_head: int, model_dim: int, num_heads: int = 2): + super().__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.tables = nn.ModuleList([ + nn.Embedding(num_buckets, d_head) for _ in range(num_heads) + ]) + for t in self.tables: + nn.init.zeros_(t.weight) + self.proj = CastedLinear(num_heads * d_head, model_dim, bias=False) + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + t = token_ids.to(torch.int32) + nb = self.num_buckets + # head 0: bigram hash(prev, curr); position 0 gets index 0 (neutral bucket) + bigram_idx = torch.zeros_like(t, dtype=torch.int64) + bigram_idx[:, 1:] = torch.abs( + torch.bitwise_xor(36313 * t[:, 1:], 27191 * t[:, :-1]) + ) % nb + # head 1: trigram hash(prev2, prev, curr); positions 0-1 get index 0 + trigram_idx = torch.zeros_like(t, dtype=torch.int64) + trigram_idx[:, 2:] = torch.abs( + torch.bitwise_xor( + torch.bitwise_xor(36313 * t[:, 2:], 27191 * t[:, 1:-1]), + 12979 * t[:, :-2], + ) + ) % nb + hashes = [bigram_idx, trigram_idx][: self.num_heads] + embs = [self.tables[i](h) for i, h in enumerate(hashes)] + out = self.proj(torch.cat(embs, dim=-1)) # (B, S, model_dim) + return out * self.scale.to(dtype=out.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, + learned_qm: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, learned_qm=learned_qm) + self.mlp = MLP(dim, mlp_mult, learned_qm=learned_qm) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +# ─── TTSM: Selective SSM + Block ───────────────────────────────────────────── + +# Module-level compiled scan chunk function. +# Compiling only the inner 64-step loop avoids dynamo trying to unroll the full +# 32×64=2048 sequence (which hangs compilation). The outer Python loop (32 iters) +# is cheap and calls this compiled function once per chunk. +# dynamic=True: handles variable batch sizes without recompilation. +@torch.compile(dynamic=True) +def _ssm_scan_chunk( + h: Tensor, # (batch, d_inner, d_state) — incoming state + dA_c: Tensor, # (batch, chunk_size, d_inner, d_state) + dBx_c: Tensor, # (batch, chunk_size, d_inner, d_state) + C_c: Tensor, # (batch, chunk_size, d_state) + chunk_size: int, +) -> tuple[Tensor, Tensor]: + """One chunk of the SSM scan — compiled as a single CUDA kernel.""" + y_chunk = torch.empty(h.shape[0], chunk_size, h.shape[1], device=h.device, dtype=h.dtype) + for i in range(chunk_size): + h = dA_c[:, i] * h + dBx_c[:, i] + y_chunk[:, i, :] = (h * C_c[:, i, None, :]).sum(-1) + return h, y_chunk + + +class SelectiveSSM(nn.Module): + """Mamba-1 selective SSM with chunk-wise parallel scan. + + Ternary treatment (confirmed Phase 0/1, session 9-10): + B_proj, C_proj → TernaryLinear (selectivity survives ternary quantization) + dt_proj → CastedLinear (fp16 — Δ gates recurrence, ternary = unstable) + A_log → fp32 Parameter (diagonal per (d_inner, d_state), kept precise) + D → fp32 Parameter (direct skip term, standard Mamba) + + Optional: short convolution on dt/B/C before recurrence (use_short_conv=True). + GDN study (session 16): "ShortConvolution is crucial to performance" — depthwise + causal conv kernel_size=4 + SiLU gives local context the SSM can't get from one token. + Conv weights are depthwise (groups=d), ndim=3 → go to AdamW scalar, not Muon. + + Optional: L2 normalize B and C before state update (normalize_bc=True). + Prevents state explosion under ternary QAT where STE can produce large gradients. + + Scan: chunk-wise hybrid. + - Projections (dt, B, C) run over the full sequence in one parallel pass. + - State carry is sequential within each chunk (SCAN_CHUNK_SIZE steps). + - torch.compile fuses the fixed-size inner loop into a single CUDA kernel. + - Outer loop: seq_len // chunk_size iterations (32 for seq=2048, chunk=64). + + A_log init: + "diverse" → A spans log(1..d_state) per channel (tiered cache hierarchy, + standard Mamba init, confirmed better in Tink session 13) + "zero" → A = -1 for all channels (single timescale, original default) + """ + + def __init__(self, d_inner: int, d_state: int, a_log_init: str = "diverse", + scan_chunk_size: int = 64, use_short_conv: bool = False, + normalize_bc: bool = False): + super().__init__() + self.d_inner = d_inner + self.d_state = d_state + self.scan_chunk_size = scan_chunk_size + self.use_short_conv = use_short_conv + self.normalize_bc = normalize_bc + + if a_log_init == "diverse": + # Standard Mamba init: tiered cache hierarchy across state channels + a_init = torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)) + self.A_log = nn.Parameter(a_init.unsqueeze(0).expand(d_inner, -1).clone()) + else: + self.A_log = nn.Parameter(torch.zeros(d_inner, d_state)) + + self.B_proj = TernaryLinear(d_inner, d_state, bias=False) + self.C_proj = TernaryLinear(d_inner, d_state, bias=False) + self.dt_proj = CastedLinear(d_inner, d_inner, bias=False) + self.D = nn.Parameter(torch.ones(d_inner)) + + if use_short_conv: + # Depthwise causal conv, kernel_size=4, each channel independent. + # padding=3 on both sides; forward slices to seq_len for causality. + # ndim=3 weights → AdamW scalar automatically (not Muon, not ternary). + self.dt_conv = nn.Conv1d(d_inner, d_inner, kernel_size=4, groups=d_inner, padding=3, bias=True) + self.B_conv = nn.Conv1d(d_state, d_state, kernel_size=4, groups=d_state, padding=3, bias=True) + self.C_conv = nn.Conv1d(d_state, d_state, kernel_size=4, groups=d_state, padding=3, bias=True) + nn.init.zeros_(self.dt_conv.bias) + nn.init.zeros_(self.B_conv.bias) + nn.init.zeros_(self.C_conv.bias) + + @staticmethod + def _causal_conv1d(x: Tensor, conv: "nn.Conv1d") -> Tensor: + """Apply causal depthwise conv1d. + x: (batch, seq, d) → (batch, d, seq) → conv → crop to seq → (batch, seq, d) + padding=3 adds 3 zeros left + 3 right; slicing [:seq] removes right padding. + """ + seq = x.shape[1] + return conv(x.transpose(1, 2))[..., :seq].transpose(1, 2) + + @torch._dynamo.disable # prevent outer torch.compile(model) from tracing into this forward; + # _ssm_scan_chunk below is compiled separately when first called (64-step inner kernel) + def forward(self, x: Tensor) -> Tensor: + """x: (batch, seq, d_inner) → y: (batch, seq, d_inner) + + Two-level compilation strategy: + - This forward has @torch._dynamo.disable: outer model compile takes a graph break here. + Outer Python loop (32 iterations) runs eagerly — cheap. + - _ssm_scan_chunk is @torch.compile'd separately: compiles once (64-step loop, not 2048), + runs as a single fused CUDA kernel per chunk. 32× kernel launches vs 2048× Python ops. + """ + batch, seq, _ = x.shape + chunk_size = self.scan_chunk_size + assert seq % chunk_size == 0, ( + f"seq_len={seq} must be divisible by SCAN_CHUNK_SIZE={chunk_size}" + ) + + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # Full-sequence projections — all parallel (the expensive GEMMs) + dt_raw = self.dt_proj(x).float() # (batch, seq, d_inner) — pre-gate + B = self.B_proj(x).float() # (batch, seq, d_state) + C = self.C_proj(x).float() # (batch, seq, d_state) + x_f = x.float() # (batch, seq, d_inner) + + # Short convolution: local context before recurrence (GDN: "crucial to performance"). + # Apply conv BEFORE softplus for dt — softplus enforces dt > 0 for ZOH stability. + # silu after softplus would allow negative dt → positive exp(dt*A) → blow-up/NaN. + if self.use_short_conv: + dt_raw = self._causal_conv1d(dt_raw, self.dt_conv) + B = self._causal_conv1d(B, self.B_conv) + C = self._causal_conv1d(C, self.C_conv) + dt = F.softplus(dt_raw) # (batch, seq, d_inner) — always > 0 + + # L2 normalize B and C: prevents state explosion under ternary QAT + if self.normalize_bc: + B = F.normalize(B, dim=-1) + C = F.normalize(C, dim=-1) + + if _TTSM_TRITON_FWD is not None: + # ── Triton path: single kernel call, no Python outer loop ────────── + # Forward uses Triton (fast). Backward uses compiled PyTorch chunks (safe, no OOB). + # Works in both training mode and eval/inference mode. + y, _ = _TTSM_TRITON_FWD(x_f, self.A_log, dt, B, C, self.D) + return y.to(x.dtype) + + # ── Python path (fallback when TTSM_TRITON=0) ────────────────────────── + h = x.new_zeros(batch, self.d_inner, self.d_state, dtype=torch.float32) + y = torch.empty(batch, seq, self.d_inner, device=x.device, dtype=torch.float32) + + for chunk_start in range(0, seq, chunk_size): + # Reshape for broadcasting — all c=chunk_size positions computed in parallel + dt_c = dt[:, chunk_start:chunk_start + chunk_size, :, None] # (batch, c, d_inner, 1) + B_c = B[:, chunk_start:chunk_start + chunk_size, None, :] # (batch, c, 1, d_state) + x_c = x_f[:, chunk_start:chunk_start + chunk_size, :, None] # (batch, c, d_inner, 1) + C_c = C[:, chunk_start:chunk_start + chunk_size, :] # (batch, c, d_state) + + # ZOH discretization — all c positions in parallel + dA_c = torch.exp(dt_c * A[None, None, :, :]) # (batch, c, d_inner, d_state) + dBx_c = dt_c * B_c * x_c # (batch, c, d_inner, d_state) + + # Inner scan via compiled chunk function (fast CUDA kernel, no Python loop overhead) + h, y_chunk = _ssm_scan_chunk(h, dA_c, dBx_c, C_c, chunk_size) + y[:, chunk_start:chunk_start + chunk_size, :] = y_chunk + + return (y + self.D[None, None, :] * x_f).to(x.dtype) + + +class TTSMBlock(nn.Module): + """TTSM block: gated SSM sublayer + optional MLP sublayer. + + Drop-in for Block — same forward(x, x0) → x signature. + + Gated SSM structure (standard Mamba): + xz = in_proj(norm(x)) # (batch, seq, d_inner * 2) + x_in, z = split(xz, 2, dim=-1) + ssm_out = ssm(x_in) * silu(z) # gated SSM output + x += ssm_scale * out_proj(ssm_out) + + MLP sublayer (present unless SSM_ONLY=1): + x += mlp_scale * mlp(norm(x)) + + MLP cannibalization confirmed (Tink session 13): SSM-only often beats SSM+MLP + because the MLP is a training trap. SSM_ONLY=1 removes it entirely. + """ + + def __init__(self, dim: int, d_inner: int, d_state: int, mlp_mult: float, + ssm_only: bool = False, a_log_init: str = "diverse", + scan_chunk_size: int = 64, use_short_conv: bool = False, + normalize_bc: bool = False): + super().__init__() + self.d_inner = d_inner + self.ssm_only = ssm_only + + self.norm = RMSNorm() + self.in_proj = TernaryLinear(dim, d_inner * 2, bias=False) + self.ssm = SelectiveSSM(d_inner, d_state, a_log_init=a_log_init, + scan_chunk_size=scan_chunk_size, + use_short_conv=use_short_conv, + normalize_bc=normalize_bc) + self.out_proj = TernaryLinear(d_inner, dim, bias=False) + self.out_proj._zero_init = True + self.ssm_scale = nn.Parameter(torch.ones(dim)) + + if not ssm_only: + self.mlp_norm = RMSNorm() + self.mlp = MLP(dim, mlp_mult) + self.mlp_scale = nn.Parameter(torch.ones(dim)) + + self.resid_mix = nn.Parameter(torch.stack([torch.ones(dim), torch.zeros(dim)])) + + @torch._dynamo.disable # Prevent Dynamo from tracing into TTSMBlock when TTSM_TRITON=1: + # SelectiveSSM.forward calls the Triton kernel, which runs eagerly inside Dynamo traces + # and causes CUDA illegal memory access on H100 during eval-mode recompilation. + # With TTSM_TRITON=0 (Python path), this decorator is harmless (minor perf overhead). + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # SSM sublayer + xz = self.in_proj(self.norm(x)) + x_in, z = xz.chunk(2, dim=-1) + ssm_out = self.ssm(x_in) * F.silu(z) + x = x + self.ssm_scale.to(dtype=x.dtype)[None, None, :] * self.out_proj(ssm_out) + + # MLP sublayer (optional) + if not self.ssm_only: + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, TernaryLinear)): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + return F.cross_entropy(logits_f, targets, reduction="mean") + 1e-4 * lse.pow(2).mean() + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +class TTRM(nn.Module): + """Typical Ternary Recursive Model. + + num_blocks unique shared blocks × num_loops iterations = effective depth. + Loop position embeddings distinguish iterations. + U-Net skip connections per loop. + No SmearGate (ternary-incompatible). + BigramHash injected before x0 is set (amplifies across loops). + """ + + def __init__( + self, + vocab_size: int, + num_blocks: int, + num_loops: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 16, + trigram_hash: bool = False, + trigram_buckets: int = 8192, + trigram_d_head: int = 16, + learned_qm: bool = False, + # TTSM: replace first ttsm_blocks transformer blocks with SelectiveSSM blocks + ttsm_blocks: int = 0, + d_state: int = 64, + ssm_only: bool = False, + a_log_init: str = "diverse", + scan_chunk_size: int = 64, + ssm_short_conv: bool = False, + ssm_normalize_bc: bool = False, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.num_blocks = num_blocks + self.num_loops = num_loops + self.ttsm_blocks = ttsm_blocks + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.trigram = TrigramHashEmbedding(trigram_buckets, trigram_d_head, model_dim) if trigram_hash else None + # Per-loop position embeddings (learned, passthrough — not ternary) + self.loop_pos = nn.Parameter(torch.randn(num_loops, model_dim) * 0.02) + # N unique blocks: first ttsm_blocks are TTSMBlock, remainder are transformer Block + blocks_list = [] + for i in range(num_blocks): + if i < ttsm_blocks: + blocks_list.append(TTSMBlock( + dim=model_dim, + d_inner=model_dim, + d_state=d_state, + mlp_mult=mlp_mult, + ssm_only=ssm_only, + a_log_init=a_log_init, + scan_chunk_size=scan_chunk_size, + use_short_conv=ssm_short_conv, + normalize_bc=ssm_normalize_bc, + )) + else: + blocks_list.append( + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + learned_qm=learned_qm) + ) + self.blocks = nn.ModuleList(blocks_list) + # U-Net skip connections: per-loop, per-skip-pair + num_enc = num_blocks // 2 + num_dec = num_blocks - num_enc + self.num_enc = num_enc + self.num_dec = num_dec + self.skip_weights = nn.Parameter( + torch.ones(num_loops, min(num_enc, num_dec), model_dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + effective_depth = self.num_blocks * self.num_loops + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, TernaryLinear)): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * effective_depth)) + + def _forward_body(self, x: Tensor, x0: Tensor) -> Tensor: + for loop in range(self.num_loops): + x = x + self.loop_pos[loop].to(dtype=x.dtype)[None, None, :] + skips: list[Tensor] = [] + for i in range(self.num_enc): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_dec): + if skips: + sw = self.skip_weights[loop, i].to(dtype=x.dtype)[None, None, :] + x = x + sw * skips.pop() + x = self.blocks[self.num_enc + i](x, x0) + return x + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + if self.trigram is not None: + x = x + self.trigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x # bigram/trigram context baked into x0 + x = self._forward_body(x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + logits_f = logits.float() + lse = torch.logsumexp(logits_f, dim=-1) + return F.cross_entropy(logits_f, targets, reduction="mean") + 1e-4 * lse.pow(2).mean() + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + if self.trigram is not None: + x = x + self.trigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + x = self._forward_body(x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = TTRM( + vocab_size=args.vocab_size, + num_blocks=args.num_blocks, + num_loops=args.num_loops, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + trigram_hash=args.trigram_hash, + trigram_buckets=args.trigram_buckets, + trigram_d_head=args.trigram_d_head, + learned_qm=args.learned_quant_mix, + ttsm_blocks=args.ttsm_blocks, + d_state=args.d_state, + ssm_only=args.ssm_only, + a_log_init=args.a_log_init, + scan_chunk_size=args.scan_chunk_size, + ssm_short_conv=args.ssm_short_conv, + ssm_normalize_bc=args.ssm_normalize_bc, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, TernaryLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compile_fullgraph = bool(int(os.environ.get("COMPILE_FULLGRAPH", "0"))) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=compile_fullgraph) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + # When BC_LR > 0: B_proj/C_proj go to a separate Muon group with their own LR. + # When BC_LR == 0: standard — all matrix params in one group (no change in behavior). + _bc_names = ("ssm.B_proj", "ssm.C_proj") + matrix_params = [] + bc_params = [] # populated only when args.bc_lr > 0 + for _name, _p in block_named_params: + # Only strict ndim==2 params go to Muon. ndim==3 conv weights go to Adam (scalar_params). + # Note: ndim<2 (biases, scalars) and control tensors also go to Adam. + if _p.ndim != 2 or any(pattern in _name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + pass # handled by scalar_params below + elif args.bc_lr > 0 and any(bc in _name for bc in _bc_names): + bc_params.append(_p) + else: + matrix_params.append(_p) + scalar_params = [ + p for name, p in block_named_params + if p.ndim != 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.loop_pos) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + if base_model.trigram is not None: + scalar_params.append(base_model.trigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.trigram is not None: + for table in base_model.trigram.tables: + tok_params.append({"params": [table.weight], "lr": token_lr, "base_lr": token_lr}) + matrix_params.append(base_model.trigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=float(os.environ.get("MUON_WD", 0.0)), + muon_eq_r=args.muon_eq_r, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + # Selective Muon for B/C: separate optimizer group, participates in warmdown via base_lr. + if bc_params: + optimizer_bc = Muon( + bc_params, + lr=args.bc_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=float(os.environ.get("MUON_WD", 0.0)), + muon_eq_r=args.muon_eq_r, + ) + for group in optimizer_bc.param_groups: + group["base_lr"] = args.bc_lr + log0(f"selective_muon:bc_lr={args.bc_lr} bc_params={len(bc_params)}") + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if bc_params: + optimizers.append(optimizer_bc) # BC Muon: warmdown applied via base_lr in the training loop + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + _arch = "TTSM" if args.ttsm_blocks == args.num_blocks else ("TTRM+TTSM" if args.ttsm_blocks > 0 else "TTRM") + log0(f"model:{_arch} num_blocks:{args.num_blocks} num_loops:{args.num_loops} " + f"effective_depth:{args.num_blocks * args.num_loops} model_dim:{args.model_dim}") + if args.ttsm_blocks > 0: + _ssm_kind = "SSM-only" if args.ssm_only else "SSM+MLP" + log0(f"ttsm:blocks={args.ttsm_blocks}/{args.num_blocks} kind={_ssm_kind} " + f"d_state={args.d_state} a_log_init={args.a_log_init} " + f"scan_chunk_size={args.scan_chunk_size} " + f"short_conv={args.ssm_short_conv} normalize_bc={args.ssm_normalize_bc}") + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} muon_eq_r:{args.muon_eq_r}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0(f"ternary:enabled={args.ternary_enabled} per_row={args.ternary_per_row} " + f"warmstart={args.warmstart_frac} monitor={args.monitor_mode} " + f"learned_quant_mix={args.learned_quant_mix} qm_penalty={args.qm_penalty} " + f"layer_type_warmdown={args.layer_type_warmdown}") + + # TERNARY MONITOR + ternary_monitor = None + if args.ternary_enabled and args.monitor_mode != "off": + from ternary_monitor import TernaryMonitor, TernaryMonitorConfig + monitor_cfg = TernaryMonitorConfig(log_every=args.monitor_every) + ternary_monitor = TernaryMonitor(config=monitor_cfg, base_lr=args.matrix_lr) + log0(f"ternary_monitor:initialized mode={args.monitor_mode}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Ternary QAT enable/disable for warmstart transition + def set_ternary_quantize(enable: bool) -> None: + for module in base_model.modules(): + if isinstance(module, TernaryLinear): + module.quantize_enabled = enable + + ternary_activated = not (args.ternary_enabled and args.warmstart_frac > 0) + ternary_activation_step: int = -1 # step when ternary first activated + ternary_ramp_steps: int = 0 # per-block warmdown ramp length (0 = complete/disabled) + if args.ternary_enabled and args.warmstart_frac > 0: + set_ternary_quantize(False) # Start in bf16 mode + log0(f"ternary:warmstart bf16 for first {args.warmstart_frac:.0%} of training") + elif args.ternary_enabled: + set_ternary_quantize(True) + ternary_activated = True + log0("ternary:enabled from step 0") + else: + set_ternary_quantize(False) + log0("ternary:disabled (CastedLinear-equivalent mode)") + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + + # Profiler: PROFILE_STEPS=N records N steady-state steps (wait=2 for compile, warmup=1, active=N) + # Trace exported to /workspace/profile_trace.json; training stops after profiling completes. + _prof = None + if args.profile_steps > 0: + def _trace_handler(p): + path = '/workspace/profile_trace.json' + p.export_chrome_trace(path) + log0(f"profiler:trace_saved path:{path}") + _prof = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=2, warmup=1, active=args.profile_steps), + on_trace_ready=_trace_handler, + record_shapes=True, + profile_memory=False, + with_stack=False, + ) + _prof.__enter__() + stop_after_step = 2 + 1 + args.profile_steps # wait + warmup + active + log0(f"profiler:started profile_steps:{args.profile_steps} will_stop_at:{stop_after_step}") + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # EMA disabled for ternary: EMA averages latent weights off the discrete ternary lattice + if args.ternary_enabled and args.ema_enabled: + log0("WARNING: EMA+ternary averages latent weights off the discrete ternary lattice. Disabling EMA.") + args.ema_enabled = False + + # EMA state (shadow copy of model params, updated every step) + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + log0(f"ema:initialized decay={args.ema_decay}") + # No autotune warmup needed: BD=128 is hardcoded in _ttsm_scan_fwd_inner (no exploration). + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + # Intermediate checkpoint: raw state dict + .ptz (grokking-safe submission insurance) + # Saves every CHECKPOINT_EVERY steps. .ptz only saved post-ternary-activation. + if master_process and args.checkpoint_every > 0 and step > 0 and step % args.checkpoint_every == 0: + ckpt_path = f"/workspace/ckpt_s{step:06d}_seed{args.seed}.pt" + torch.save(base_model.state_dict(), ckpt_path) + log0(f"checkpoint:saved step:{step} path:{ckpt_path}") + # Also save .ptz after ternary is active (Tink: best val_bpb may be mid-run) + if ternary_activated and args.ternary_enabled: + _sd = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + _tern = ternary_serialize(_sd, per_row=args.ternary_per_row) + _buf = io.BytesIO() + torch.save(_tern, _buf) + _raw = _buf.getvalue() + _blob = zstandard.ZstdCompressor(level=22).compress(_raw) if _COMPRESSOR == "zstd" else zlib.compress(_raw, 9) + _ptz_path = f"/workspace/ckpt_s{step:06d}_seed{args.seed}.ptz" + with open(_ptz_path, "wb") as _f: + _f.write(_blob) + log0(f"checkpoint:ptz_saved step:{step} path:{_ptz_path} size:{len(_blob)}") + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Warmstart → ternary QAT transition + # Branchless: _quant_mix buffer changes from 0→1, no recompile needed + if not ternary_activated and args.ternary_enabled: + should_activate = False + if max_wallclock_ms is not None: + should_activate = elapsed_ms >= max_wallclock_ms * args.warmstart_frac + else: + should_activate = step >= args.iterations * args.warmstart_frac + if should_activate: + if args.layer_type_warmdown and hasattr(base_model, "blocks"): + # Ramp attn at 1x, MLP at 0.5x; start from quant_mix=0 + set_ternary_quantize(False) # holds at 0.0 until ramp updates + if max_wallclock_ms is not None: + step_ms = elapsed_ms / max(step, 1) + remaining_ms = max_wallclock_ms - elapsed_ms + ternary_ramp_steps = max(1, int(remaining_ms * args.warmstart_frac / max(step_ms, 1e-9))) + else: + ternary_ramp_steps = max(1, int((args.iterations - step) * args.warmstart_frac)) + ternary_activation_step = step + log0(f"ternary:activated layer_type_warmdown ramp_steps:{ternary_ramp_steps} step:{step}") + elif args.per_block_warmdown and hasattr(base_model, "blocks"): + # Start ramp from quant_mix=0; each block ramps at own rate + set_ternary_quantize(False) # holds at 0.0 until ramp updates + if max_wallclock_ms is not None: + step_ms = elapsed_ms / max(step, 1) + remaining_ms = max_wallclock_ms - elapsed_ms + ternary_ramp_steps = max(1, int(remaining_ms * args.warmstart_frac / max(step_ms, 1e-9))) + else: + ternary_ramp_steps = max(1, int((args.iterations - step) * args.warmstart_frac)) + ternary_activation_step = step + log0(f"ternary:activated per_block_warmdown ramp_steps:{ternary_ramp_steps} step:{step}") + else: + set_ternary_quantize(True) + log0(f"ternary:activated at step:{step} elapsed:{elapsed_ms:.0f}ms") + ternary_activated = True + + # Per-block warmdown: ramp late blocks (last 2) at 0.5x crystallization rate + if (args.per_block_warmdown and ternary_activated + and ternary_activation_step >= 0 and ternary_ramp_steps > 0 + and hasattr(base_model, "blocks")): + steps_since = step - ternary_activation_step + n_blocks = len(base_model.blocks) + late_start = max(0, n_blocks - 2) + all_done = True + for i, block in enumerate(base_model.blocks): + rate = 0.5 if i >= late_start else 1.0 + qm = min(1.0, steps_since * rate / ternary_ramp_steps) + if qm < 1.0: + all_done = False + for m in block.modules(): + if isinstance(m, TernaryLinear): + m._quant_mix.fill_(qm) + if all_done: + ternary_ramp_steps = 0 # stop per-step updates once fully crystallized + log0(f"per_block_warmdown:complete at step:{step}") + + # Layer-type warmdown: attn/ssm layers at 1x ramp, MLP layers at 0.5x ramp + if (args.layer_type_warmdown and ternary_activated + and ternary_activation_step >= 0 and ternary_ramp_steps > 0 + and hasattr(base_model, "blocks")): + steps_since = step - ternary_activation_step + all_done = True + for block in base_model.blocks: + qm_attn = min(1.0, steps_since / ternary_ramp_steps) + qm_mlp = min(1.0, steps_since / (ternary_ramp_steps * 2)) + if qm_attn < 1.0 or qm_mlp < 1.0: + all_done = False + if isinstance(block, TTSMBlock): + # SSM projections at attn rate (1x); MLP (if present) at 0.5x + for m in (block.in_proj, block.out_proj, + block.ssm.B_proj, block.ssm.C_proj): + if isinstance(m, TernaryLinear): + m._quant_mix.fill_(qm_attn) + else: + for attr in ("c_q", "c_k", "c_v", "proj"): + m = getattr(block.attn, attr, None) + if isinstance(m, TernaryLinear): + m._quant_mix.fill_(qm_attn) + if hasattr(block, "mlp"): + for m in block.mlp.modules(): + if isinstance(m, TernaryLinear): + m._quant_mix.fill_(qm_mlp) + if all_done: + ternary_ramp_steps = 0 + log0(f"layer_type_warmdown:complete at step:{step}") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + if args.learned_quant_mix and ternary_activated and args.qm_penalty > 0: + qm_vals = torch.stack([ + torch.sigmoid(m.qm_logit) + for m in base_model.modules() + if isinstance(m, TernaryLinear) and hasattr(m, "qm_logit") + ]) + loss = loss + ((1.0 - qm_vals) ** 2).mean() * args.qm_penalty + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + if _prof is not None: + _prof.step() + if step >= stop_after_step: + _prof.__exit__(None, None, None) + _prof = None + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # EMA: update shadow weights every step + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().cpu(), alpha=1 - d) + + # SWA: collect checkpoints during warmdown (fallback if EMA disabled) + if args.swa_enabled and not args.ema_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Ternary monitor: observe (and optionally act on) training dynamics + if (ternary_monitor is not None and ternary_activated + and step % args.monitor_every == 0): + mon_state, lr_factor = ternary_monitor.step(base_model, train_loss.item()) + + if should_log_train or step % ternary_monitor.config.log_every == 0: + log0(ternary_monitor.format_status()) + + # In "adapt" mode, apply the LR factor + if args.monitor_mode == "adapt": + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale * lr_factor + + # Gradient perturbation if frozen + if ternary_monitor.should_perturb(): + with torch.no_grad(): + for p in base_model.parameters(): + if p.grad is not None and p.ndim == 2: + noise = torch.randn_like(p.grad) * p.grad.abs().mean() + p.grad.add_(noise * ternary_monitor.config.perturb_magnitude) + + if (args.committee_log_interval > 0 and ternary_activated + and step % args.committee_log_interval == 0): + log0(compute_committee_density(base_model)) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Monitor summary + if ternary_monitor is not None and master_process: + summary = ternary_monitor.get_summary() + log0(f"ternary_monitor_summary: {summary}") + + # Apply EMA or SWA + if ema_state is not None: + log0(f"ema:applying decay={args.ema_decay}") + current_state = base_model.state_dict() + ema_applied = { + name: tensor.to(dtype=current_state[name].dtype) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_applied, strict=True) + elif args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # Log ternary weight distribution + if master_process and args.ternary_enabled: + total_ternary = 0 + total_neg = 0 + total_zero = 0 + total_pos = 0 + for module in base_model.modules(): + if isinstance(module, TernaryLinear): + t, _ = module.get_ternary_weights() + n = t.numel() + total_ternary += n + total_neg += (t == -1).sum().item() + total_zero += (t == 0).sum().item() + total_pos += (t == 1).sum().item() + if total_ternary > 0: + log0(f"ternary_distribution: total:{total_ternary:,} " + f"-1:{total_neg/total_ternary:.1%} 0:{total_zero/total_ternary:.1%} " + f"+1:{total_pos/total_ternary:.1%}") + + # TERNARY SERIALIZATION + ROUNDTRIP VALIDATION + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + + # Serialize: ternary body weights via trit packing, passthrough for embeddings/control + ternary_obj = ternary_serialize(sd_cpu, per_row=args.ternary_per_row) + + # Compress + quant_buf = io.BytesIO() + torch.save(ternary_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + + artifact_path = f"final_model.ternary.seed{args.seed}.ptz" + if master_process: + with open(artifact_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(artifact_path) + code_bytes = len(code.encode("utf-8")) + n_ternary = sum(math.prod(meta["shape"]) for meta in ternary_obj["metadata"].values()) + n_pass = sum(t.numel() for t in ternary_obj["passthrough"].values()) + log0(f"ternary_serialize: ternary_params:{n_ternary} passthrough_params:{n_pass}") + log0(f"Serialized model ternary+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress → deserialize → load → eval + if distributed: + dist.barrier() + with open(artifact_path, "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + loaded_obj = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = ternary_deserialize(loaded_obj, sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on ternary-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, base_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_ternary_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_ternary_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/trit_packing.py b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/trit_packing.py new file mode 100644 index 0000000000..482c1ec3ab --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/trit_packing.py @@ -0,0 +1,318 @@ +""" +Trit packing: serialize ternary {-1, 0, +1} weights at 5 trits per byte. + +3^5 = 243 < 256, so 5 ternary values fit in one byte. +This gives 1.6 bits/param — already at entropy (zstd/zlib achieve ~0% further compression). + +For comparison: + - int8: 8 bits/param + - int6: 6 bits/param (stored as int8, but only 6 bits used) + - int5: 5 bits/param (stored as int8, but only 5 bits used) + - ternary packed: 1.6 bits/param + +The density advantage is 3-4x over int5/int6 in the same 16MB budget. +""" + +from __future__ import annotations + +import io +import math +from typing import Any + +import torch +from torch import Tensor + +try: + import zstandard + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import zlib + + +# ─── Trit packing / unpacking ─────────────────────────────────────────────── + +_TRITS_PER_BYTE = 5 +_POWERS = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int32) + + +def pack_trits(tensor: Tensor) -> tuple[Tensor, tuple[int, ...], int]: + """Pack ternary {-1, 0, +1} tensor into bytes. 5 trits per byte. + + Args: + tensor: int8 tensor with values in {-1, 0, +1} + + Returns: + packed: uint8 tensor of packed bytes + shape: original tensor shape + pad: number of padding trits added (to make length divisible by 5) + """ + shape = tuple(tensor.shape) + # Map: -1 → 0, 0 → 1, +1 → 2 + flat = (tensor.reshape(-1).to(torch.int32) + 1) + + # Pad to multiple of 5 + n = flat.numel() + pad = (_TRITS_PER_BYTE - n % _TRITS_PER_BYTE) % _TRITS_PER_BYTE + if pad > 0: + flat = torch.cat([flat, torch.ones(pad, dtype=torch.int32)]) # pad with 1 (=0 in ternary) + + # Pack: groups of 5 → single byte + # value = t0 + 3*t1 + 9*t2 + 27*t3 + 81*t4 + groups = flat.reshape(-1, _TRITS_PER_BYTE) + powers = _POWERS.to(groups.device) + packed = (groups * powers).sum(dim=1).to(torch.uint8) + + return packed, shape, pad + + +def unpack_trits(packed: Tensor, shape: tuple[int, ...], pad: int) -> Tensor: + """Unpack bytes back to ternary {-1, 0, +1} tensor. + + Args: + packed: uint8 tensor from pack_trits + shape: original tensor shape + pad: padding count from pack_trits + + Returns: + Tensor of int8 with values in {-1, 0, +1}, reshaped to original shape + """ + # Unpack each byte into 5 trits + values = packed.to(torch.int32) + trits = torch.zeros(packed.numel() * _TRITS_PER_BYTE, dtype=torch.int32, + device=packed.device) + + for i in range(_TRITS_PER_BYTE): + trits[i::_TRITS_PER_BYTE] = values % 3 + values = values // 3 + + # Remove padding + total = math.prod(shape) + trits = trits[:total] + + # Map back: 0 → -1, 1 → 0, 2 → +1 + return (trits - 1).to(torch.int8).reshape(shape) + + +# ─── Vectorized unpack (faster for large tensors) ─────────────────────────── + +def unpack_trits_vectorized(packed: Tensor, shape: tuple[int, ...], pad: int) -> Tensor: + """Vectorized version of unpack_trits. ~3x faster for large tensors.""" + values = packed.to(torch.int32).unsqueeze(1).expand(-1, _TRITS_PER_BYTE) + divisors = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int32, device=packed.device) + trits = (values // divisors) % 3 + trits = trits.reshape(-1) + + total = math.prod(shape) + trits = trits[:total] + return (trits - 1).to(torch.int8).reshape(shape) + + +# ─── Model serialization ──────────────────────────────────────────────────── + +# Parameters matching these patterns are kept in fp32 (control tensors) +CONTROL_PATTERNS = ( + "attn_scale", "mlp_scale", "resid_mix", "q_gain", + "skip_weight", "skip_weights", "smear", "bigram.scale", +) + +# Parameters matching these patterns are kept in fp16 (embeddings etc) +PASSTHROUGH_PATTERNS = ("tok_emb", "lm_head", "bigram.embed", "bigram.proj") + + +def is_ternary_param(name: str, tensor: Tensor) -> bool: + """Determine if a parameter should be serialized as ternary.""" + # Skip small tensors + if tensor.numel() <= 8192: + return False + # Skip control tensors + if any(p in name for p in CONTROL_PATTERNS): + return False + # Skip embeddings/head + if any(p in name for p in PASSTHROUGH_PATTERNS): + return False + # Only 2D weight matrices + if tensor.ndim != 2: + return False + return True + + +def serialize_ternary_model( + state_dict: dict[str, Tensor], + per_row: bool = True, +) -> dict[str, Any]: + """Serialize a model with ternary weights for the 16MB artifact. + + Args: + state_dict: model state dict (with continuous weights from training) + per_row: use per-row scales (recommended) + + Returns: + Serializable dict with packed ternary weights + passthrough params + """ + obj: dict[str, Any] = { + "__format__": "ternary_packed_v1", + "ternary": {}, + "scales": {}, + "metadata": {}, + "passthrough": {}, + } + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + + if is_ternary_param(name, t): + t_float = t.float() + + # Compute scale + if per_row and t_float.ndim == 2: + scale = t_float.abs().mean(dim=1, keepdim=True).clamp_min(1e-8) + else: + scale = t_float.abs().mean().clamp_min(1e-8) + + # Quantize to ternary + ternary = torch.clamp(torch.round(t_float / scale), -1, 1).to(torch.int8) + + # Pack trits + packed, shape, pad = pack_trits(ternary) + + obj["ternary"][name] = packed + obj["scales"][name] = scale.to(torch.float16) + obj["metadata"][name] = {"shape": list(shape), "pad": pad} + + elif any(p in name for p in CONTROL_PATTERNS): + # Control tensors: keep fp32 + obj["passthrough"][name] = t.float() + + else: + # Everything else (embeddings, small tensors): keep fp16 + if t.is_floating_point(): + obj["passthrough"][name] = t.to(torch.float16) + else: + obj["passthrough"][name] = t + + return obj + + +def deserialize_ternary_model( + obj: dict[str, Any], + template_sd: dict[str, Tensor], +) -> dict[str, Tensor]: + """Deserialize packed ternary model back to a state dict. + + Args: + obj: dict from serialize_ternary_model (after torch.load) + template_sd: reference state dict for dtype/shape info + + Returns: + Reconstituted state dict with bf16/fp32 weights + """ + out: dict[str, Tensor] = {} + + for name, orig in template_sd.items(): + if name in obj["ternary"]: + # Unpack ternary + packed = obj["ternary"][name] + meta = obj["metadata"][name] + shape = tuple(meta["shape"]) + pad = meta["pad"] + + ternary = unpack_trits_vectorized(packed, shape, pad) + scale = obj["scales"][name].float() + + # Dequantize: ternary * scale → bf16 + if scale.ndim > 1: + # Per-row: scale is (out_features, 1) + dequant = (ternary.float() * scale).to(orig.dtype) + else: + dequant = (ternary.float() * scale.item()).to(orig.dtype) + + out[name] = dequant + + elif name in obj["passthrough"]: + t = obj["passthrough"][name] + if t.is_floating_point() and t.dtype != orig.dtype: + out[name] = t.to(orig.dtype) + else: + out[name] = t + else: + raise KeyError(f"Parameter '{name}' not found in serialized model") + + return out + + +def compress_artifact(obj: dict[str, Any], level: int = 22) -> bytes: + """Serialize and compress the ternary model artifact.""" + buf = io.BytesIO() + torch.save(obj, buf) + raw = buf.getvalue() + + if _HAS_ZSTD: + return zstandard.ZstdCompressor(level=level).compress(raw) + else: + return zlib.compress(raw, 9) + + +def decompress_artifact(blob: bytes) -> dict[str, Any]: + """Decompress and deserialize a ternary model artifact.""" + try: + if _HAS_ZSTD: + raw = zstandard.ZstdDecompressor().decompress(blob) + else: + raw = zlib.decompress(blob) + except Exception: + # Try the other compressor + try: + raw = zlib.decompress(blob) + except Exception: + if _HAS_ZSTD: + raw = zstandard.ZstdDecompressor().decompress(blob) + else: + raise + + return torch.load(io.BytesIO(raw), map_location="cpu") + + +# ─── Roundtrip validation ─────────────────────────────────────────────────── + +def validate_roundtrip(state_dict: dict[str, Tensor], per_row: bool = True) -> dict[str, float]: + """Validate serialize → compress → decompress → deserialize roundtrip. + + Returns dict with size info and max reconstruction error. + """ + # Serialize + obj = serialize_ternary_model(state_dict, per_row=per_row) + + # Compress + blob = compress_artifact(obj) + + # Decompress + deserialize + obj2 = decompress_artifact(blob) + recon = deserialize_ternary_model(obj2, state_dict) + + # Check reconstruction + max_err = 0.0 + ternary_params = 0 + passthrough_params = 0 + for name, orig in state_dict.items(): + if name in recon: + err = (orig.float() - recon[name].float()).abs().max().item() + max_err = max(max_err, err) + if name in obj["ternary"]: + ternary_params += orig.numel() + else: + passthrough_params += orig.numel() + + total_params = ternary_params + passthrough_params + compressor = "zstd" if _HAS_ZSTD else "zlib" + + return { + "compressed_bytes": len(blob), + "compressed_mb": len(blob) / 1_000_000, + "ternary_params": ternary_params, + "passthrough_params": passthrough_params, + "total_params": total_params, + "bits_per_ternary_param": len(blob) * 8 / max(ternary_params, 1), + "max_reconstruction_error": max_err, + "compressor": compressor, + } diff --git a/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ttsm_triton_scan.py b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ttsm_triton_scan.py new file mode 100644 index 0000000000..b8c3a034a5 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_TTSM_TernarySSM/ttsm_triton_scan.py @@ -0,0 +1,538 @@ +"""TTSM Triton Chunk-Wise Parallel Scan + +Adapted from fla HGRN chunk kernel (MIT License): + https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/hgrn/chunk.py + Authors: Songlin Yang, Yu Zhang, Zhiyuan Li + +The TTSM recurrence: + h_t = exp(dt_t * A_log) * h_{t-1} + dt_t * B_t * x_t + +maps exactly to the HGRN recurrence: + h_t = exp(g_t) * h_{t-1} + x_t + +with: + g_t = dt_t * A_log (log-domain gate per channel) + x_t = dt_t * B_t * x_inner_t (dt-scaled input after B projection) + +Implementation notes: +- D = d_inner * d_state (flattened state dim for HGRN, reshaped before/after) +- All (d_inner, d_state) channel pairs run in parallel within each chunk +- Sequential state carry only between chunks (num_chunks = seq_len // BT) +- Forward + backward both implemented; autograd Function wraps them + +No external fla dependency — kernel code is self-contained. +Ternary B/C projections handled by TernaryLinear outside this module (Phase 1). +Phase 2: fuse ternary projection bitwise ops inside the scan kernel. +""" + +import math +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +# --------------------------------------------------------------------------- +# Triton inner-chunk forward: sequential scan within BT steps +# --------------------------------------------------------------------------- + +# BD=128 hardcoded — eliminates @triton.autotune exploration. +# Root cause of training crashes: autotune runs N configs × the kernel on the same in-place +# buffer (Triton #6547, #1563). With BD fixed, no exploration → no in-place corruption. +# BD=128 divides D_flat=36864 exactly (288 programs), good for H100 occupancy. +@triton.jit(do_not_specialize=['T']) +def _ttsm_scan_fwd_inner( + x, # [B, T, D] — dt-scaled B projection output (scan input) + g, # [B, T, D] — log gate = dt * A_log (negative = decay) + gc, # [B, T, D] — cumulative log gate (output, for inter-chunk pass) + o, # [B, T, D] — hidden state output + h0, # [B, D] — initial state (optional) + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, +): + """One program per (channel_block, chunk, batch) — sequential scan within chunk.""" + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + # Cast i_b to int64: eval uses B=32, i_b * T * D = 29 * 2048 * 36864 = 2.19B > INT32_MAX. + # int32 overflow → negative pointer offset → H100 illegal memory access at eval time. + # Training uses B=2 (no overflow), so this was invisible during training-only smoke tests. + i_b = i_b.to(tl.int64) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + i_b * T * D + i_t * BT * D + o_d + p_g = g + i_b * T * D + i_t * BT * D + o_d + p_gc = gc + i_b * T * D + i_t * BT * D + o_d + p_o = o + i_b * T * D + i_t * BT * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + b_gc = tl.zeros([BD], dtype=tl.float32) + + if USE_INITIAL_STATE: + if i_t == 0: + b_h += tl.load(h0 + i_b * D + o_d, mask=mask, other=0).to(tl.float32) + + for i in range(0, BT): + mask_t = mask & ((i_t * BT + i) < T) + b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32) + # Core recurrence: h = exp(g) * h + x + b_h = tl.exp(b_g) * b_h + b_x + b_gc = b_gc + b_g + tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t) + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t) + p_x += D + p_g += D + p_gc += D + p_o += D + + +@triton.jit(do_not_specialize=['T']) +def _ttsm_scan_fwd_inter( + gc, # [B, T, D] — cumulative log gate from inner pass + o, # [B, T, D] — hidden states to be corrected in-place + s_b, s_t, s_d, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + """One program per (channel_block, batch) — propagate final state across chunks.""" + i_d, i_b = tl.program_id(0), tl.program_id(1) + # Same int32 overflow guard as inner kernel: eval B=32, i_b=29 overflows int32. + i_b = i_b.to(tl.int64) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(1, tl.cdiv(T, BT)): + p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # Final hidden state of previous chunk + b_h0 = tl.load(o + i_b * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32) + # Cumulative log-gate within this chunk + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + # Apply inter-chunk correction: o += exp(cumulative_g) * h_prev_chunk + b_o = b_o + tl.exp(b_gc) * b_h0[None, :] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +# --------------------------------------------------------------------------- +# Triton backward: within-chunk backward (reverse scan) +# --------------------------------------------------------------------------- + +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=nw) + for BD in [32, 64, 128] + for nw in [1, 2, 4, 8] + ], + key=['D'], +) +@triton.jit(do_not_specialize=['T']) +def _ttsm_scan_bwd_inner( + g, gc, dx, do, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + """Within-chunk backward — reverse sequential scan.""" + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + BC = min(BT, T - i_t * BT) + NT = tl.num_programs(1) + + p_g = g + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_gc = gc + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_dx = dx + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_do = do + (i_b * T + i_t * BT + BC - 1) * D + o_d + + if i_t == NT - 1: + b_gc = tl.zeros([BD], dtype=tl.float32) + else: + b_gc = tl.load(g + (i_b * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32) + + b_dh = tl.zeros([BD], dtype=tl.float32) + for _ in range(BC - 1, -1, -1): + tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask) + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + b_gc = b_gc + b_g + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * tl.exp(b_g) + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + p_g -= D + p_gc -= D + p_dx -= D + p_do -= D + + +@triton.jit(do_not_specialize=['T']) +def _ttsm_scan_bwd_inter( + g, gc, o, dx, dg, + s_b, s_t, s_d, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + """Inter-chunk backward — propagate gradient across chunks.""" + i_d, i_b = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_g = tl.make_block_ptr(g + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # H100 OOB fix #1: clamp next-chunk address to avoid hardware exception on masked loads. + # When i_t = T/BT - 1 (last chunk), (i_t+1)*BT == T → address 1 past tensor end. + # H100 validates addresses even for mask=False loads → illegal memory access. + # Fix: use offset 0 as safe fallback when has_next=False (value discarded anyway). + has_next = (i_t + 1) * BT < T + safe_next_offset = tl.where(has_next, (i_t + 1) * BT * D, 0) + mask_t = mask & has_next + b_ht = tl.load(dx + i_b * T * D + safe_next_offset + o_d, mask=mask_t, other=0).to(tl.float32) + + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32) + + # HGRN original pattern: block pointer + boundary_check=(0,1) handles row -1 correctly. + # On H100, block pointer OOB uses descriptor-level bounds (shape [0, T)), NOT hardware + # address validation. Row -1 returns 0 safely — no hardware exception. + # The broken "Python if has_prev" fix was wrong: i_t is a Triton runtime variable, so + # Python `if` evaluates the Triton tensor as ALWAYS truthy → always generates row -1 + # address as a raw pointer → hardware exception. Block pointer is the safe path. + p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0)) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + # boundary_check already returns 0 for row -1 (when i_t=0) — no explicit zeroing needed + + b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :] + b_dg = b_o * b_dx * tl.exp(b_g) + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +# --------------------------------------------------------------------------- +# PyTorch backward chunk (compiled — replaces Triton backward kernels) +# --------------------------------------------------------------------------- + +@torch.compile(dynamic=True) +def _backward_scan_chunk( + delta_h: torch.Tensor, # [B, D] — accumulated Δh coming from NEXT chunk + do_chunk: torch.Tensor, # [B, C, D] — ∂L/∂h for positions in this chunk + g_chunk: torch.Tensor, # [B, C, D] — log gates for this chunk + h_prev_chunk: torch.Tensor, # [B, C, D] — h[t-1] for each position + g_after_chunk: torch.Tensor, # [B, D] — g[chunk_end+1] or zeros if last chunk +) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]": + """Backward through one chunk of the sequential SSM scan. + + Recurrence: Δh[t] = do[t] + exp(g[t+1]) * Δh[t+1] + At chunk boundary: Δh[chunk_end] = do[chunk_end] + exp(g[chunk_end+1]) * delta_h_in + + No Python conditional in the inner loop — allows torch.compile to fuse all C + iterations into a single CUDA kernel instead of C separate launches. + + Returns: (delta_h_out [B, D], dx_chunk [B, C, D], dg_chunk [B, C, D]) + """ + B, C, D = do_chunk.shape + exp_g = g_chunk.exp() # [B, C, D] — exp(g[t]) for all positions t in chunk + + # "Next gate" for each position: exp(g[t+1]). + # For positions 0..C-2: exp(g_chunk[:, t+1, :]) + # For position C-1 (last in chunk): exp(g_after_chunk) — the gate of the NEXT chunk's first pos + exp_g_next = torch.cat([g_chunk[:, 1:, :], g_after_chunk.unsqueeze(1)], dim=1).exp() # [B, C, D] + + dx_out = torch.empty_like(do_chunk) + dg_out = torch.empty_like(g_chunk) + + # Sequential backward (no conditional → fully compilable loop): + for i in range(C - 1, -1, -1): + delta_h = do_chunk[:, i, :] + exp_g_next[:, i, :] * delta_h + dx_out[:, i, :] = delta_h + dg_out[:, i, :] = delta_h * exp_g[:, i, :] * h_prev_chunk[:, i, :] + + return delta_h, dx_out, dg_out + + +# --------------------------------------------------------------------------- +# autograd.Function wrapper +# --------------------------------------------------------------------------- + +class _TTSMScanFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, g, initial_state=None, output_final_state=False): + """ + x: [B, T, D] — scan input (dt_scaled * B_out, reshaped) + g: [B, T, D] — log gate (dt * A_log, reshaped) + initial_state: [B, D] or None + Returns: (o [B, T, D], final_state [B, D] or None) + """ + B, T, D = x.shape + BT = 128 # Triton chunk size + BD = 128 # hardcoded — no autotune (eliminates Triton #6547 in-place corruption) + num_warps = 4 + + gc = torch.empty_like(g, dtype=torch.float32) + o = torch.empty_like(x, dtype=torch.float32) + + # Pin Triton launches to PyTorch's current CUDA stream. + # Root cause of gc lifetime bug (Tink s17 analysis): Triton may use the + # default CUDA stream while PyTorch's caching allocator tracks gc on the + # per-device stream, allowing gc to be freed/reused while the GPU still + # writes to it. The with-stream + record_stream combo ensures the allocator + # sees the same stream the kernels run on. + _cur_stream = torch.cuda.current_stream() + with torch.cuda.stream(_cur_stream): + # No meta['BD'] — kernel called directly with fixed BD (no autotune exploration) + grid_inner = (triton.cdiv(D, BD), triton.cdiv(T, BT), B) + _ttsm_scan_fwd_inner[grid_inner]( + x, g, gc, o, initial_state, + T=T, D=D, BT=BT, BD=BD, + USE_INITIAL_STATE=initial_state is not None, + num_warps=num_warps, + ) + + grid_inter = (triton.cdiv(D, BD), B) + _ttsm_scan_fwd_inter[grid_inter]( + gc, o, + o.stride(0), o.stride(1), o.stride(2), + T=T, D=D, BT=BT, BD=BD, + num_warps=num_warps, + ) + + # Belt-and-suspenders: mark gc as in-use on this stream so the allocator + # won't reclaim its memory until the kernel completes — guards eval/no_grad paths + # where ctx.save_for_backward may not extend gc's lifetime. + gc.record_stream(_cur_stream) + + final_state = o[:, -1].clone() if output_final_state else None + o_out = o.to(x.dtype) + # gc lifetime: record_stream (above) keeps gc alive for kernel completion. + # save_for_backward pins gc through the entire autograd graph — only needed + # during training. During eval, skip it to avoid ~4.8GB wasted pinning. + ctx._has_initial_state = initial_state is not None + if torch.is_grad_enabled(): + if ctx._has_initial_state: + ctx.save_for_backward(g, gc, o, initial_state) + else: + ctx.save_for_backward(g, gc, o) + else: + if ctx._has_initial_state: + ctx.save_for_backward(g, o, initial_state) + else: + ctx.save_for_backward(g, o) + ctx.BT = BT + ctx.BD = BD + ctx.num_warps = num_warps + return o_out, final_state + + @staticmethod + def backward(ctx, do, dht=None): + """Reversed-scan backward via the Triton forward kernel. + + The backward recurrence Δh[t] = do[t] + exp(g[t+1]) * Δh[t+1] is + structurally identical to the forward scan h[t] = x[t] + exp(g[t]) * h[t-1] + but reversed in time. Run the forward kernel on time-reversed inputs. + """ + saved = ctx.saved_tensors + if len(saved) == 4: + g, _gc, o, initial_state = saved + elif len(saved) == 3 and ctx._has_initial_state: + g, o, initial_state = saved + elif len(saved) == 3: + g, _gc, o = saved + initial_state = None + else: + g, o = saved + initial_state = None + B, T, D = do.shape + + if initial_state is not None: + h_prev = torch.cat([initial_state.unsqueeze(1), o[:, :-1, :]], dim=1) + else: + h_prev = torch.cat([torch.zeros_like(o[:, :1, :]), o[:, :-1, :]], dim=1) + + g_f = g.float() + do_f = do.float() + h_prev_f = h_prev.float() + + # Reversed-scan: g_bwd[τ] = g[T-τ] for τ>0, 0 for τ=0 + g_bwd = torch.cat([torch.zeros_like(g_f[:, :1, :]), + g_f[:, 1:, :].flip(dims=[1])], dim=1) + with torch.no_grad(): + dx_rev, _ = ttsm_triton_scan( + do_f.flip(dims=[1]).contiguous(), + g_bwd.contiguous(), + ) + dx = dx_rev.flip(dims=[1]) + dg = dx * g_f.exp() * h_prev_f + + return dx.to(o.dtype), dg, None, None + + +@torch.compiler.disable +def ttsm_triton_scan( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Chunk-wise parallel scan for TTSM recurrence. + + Args: + x: scan input [B, T, D] — dt_expanded * B_out (fp32 or bf16) + g: log gate [B, T, D] — dt_expanded * A_log (fp32 or bf16, negative = decay) + initial_state: [B, D] or None + output_final_state: whether to return h_T + + Returns: + (h_all [B, T, D], final_state [B, D] or None) + """ + return _TTSMScanFunction.apply(x, g, initial_state, output_final_state) + + +# --------------------------------------------------------------------------- +# Drop-in replacement for SelectiveSSM.forward (use in train_ternary.py) +# --------------------------------------------------------------------------- + +def selective_ssm_triton_forward( + x: torch.Tensor, + A_log: torch.Tensor, + dt: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D: torch.Tensor, + initial_state: torch.Tensor = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Run TTSM selective SSM using Triton chunk scan. + + Args: + x: [B, T, d_inner] — gated input (after in_proj split) + A_log: [d_inner, d_state] — log of -A (positive, diverse init) + dt: [B, T, d_inner] — discretization step (after softplus, always > 0) + B: [B, T, d_state] — input matrix projection (after TernaryLinear) + C: [B, T, d_state] — output matrix projection (after TernaryLinear) + D: [d_inner] — direct skip connection + initial_state: [B, d_inner * d_state] or None + + Returns: + (y [B, T, d_inner], final_state [B, d_inner * d_state]) + """ + batch, seq, d_inner = x.shape + d_state = B.shape[-1] + D_flat = d_inner * d_state + + x_f = x.float() + B_f = B.float() + C_f = C.float() + dt_f = dt.float() + + # Log gate: g = dt * A where A = -exp(A_log) < 0 + # Matches SelectiveSSM.forward: A = -exp(A_log); dA = exp(dt * A) = exp(g) + # g = -dt * exp(A_log) is negative → exp(g) ∈ (0,1) ✓ + A_abs = A_log.float().exp() # exp(A_log) > 0, [d_inner, d_state] + g = -dt_f[:, :, :, None] * A_abs[None, None, :, :] # [B, T, d_inner, d_state] + g = g.reshape(batch, seq, D_flat).contiguous() + + # Scan input: x_scan[b,t,i,d] = dt[b,t,i] * B[b,t,d] * x[b,t,i] + x_scan = (dt_f[:, :, :, None] * B_f[:, :, None, :] * x_f[:, :, :, None]) # [B, T, d_inner, d_state] + x_scan = x_scan.reshape(batch, seq, D_flat).contiguous() + + # Run Triton scan + h_flat, final_state = ttsm_triton_scan( + x_scan, g, + initial_state=initial_state, + output_final_state=True, + ) # h_flat: [B, T, d_inner * d_state] + + h_all = h_flat.reshape(batch, seq, d_inner, d_state) # [B, T, d_inner, d_state] + + # Readout: y[b,t,i] = sum_d C[b,t,d] * h[b,t,i,d] + y = (C_f[:, :, None, :] * h_all).sum(-1) # [B, T, d_inner] + + y = y + D.float()[None, None, :] * x_f + return y.to(x.dtype), final_state + + +# --------------------------------------------------------------------------- +# Correctness test (run this on the pod: python ours/ttsm_triton_scan.py) +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("Testing TTSM Triton scan vs Python reference...") + torch.manual_seed(42) + + B, T, d_inner, d_state = 2, 512, 64, 32 # small for fast test + D_flat = d_inner * d_state + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Device: {device}, D_flat={D_flat}") + + # Random inputs (fp32 for reference) + g = (-0.1 * torch.rand(B, T, D_flat)).to(device) # negative log gate + x_scan = torch.randn(B, T, D_flat, device=device) * 0.1 # small scan input + + # Python reference (sequential scan) + h_ref = torch.zeros(B, D_flat, device=device) + o_ref = torch.zeros(B, T, D_flat, device=device) + for t in range(T): + h_ref = g[:, t, :].exp() * h_ref + x_scan[:, t, :] + o_ref[:, t, :] = h_ref + + # Triton scan + o_tri, _ = ttsm_triton_scan(x_scan, g) + + # Compare + abs_err = (o_tri.float() - o_ref).abs() + max_err = abs_err.max().item() + mean_err = abs_err.mean().item() + print(f"Max abs error: {max_err:.2e}") + print(f"Mean abs error: {mean_err:.2e}") + # Check if error concentrates at chunk boundaries (BT=128) + BT_kernel = 128 + boundary_positions = list(range(BT_kernel - 1, T, BT_kernel)) + interior_positions = [t for t in range(T) if t not in boundary_positions] + if boundary_positions: + boundary_err = abs_err[:, boundary_positions, :].max().item() + interior_err = abs_err[:, interior_positions, :].max().item() if interior_positions else 0.0 + print(f" Boundary max error: {boundary_err:.2e} (chunk-end positions)") + print(f" Interior max error: {interior_err:.2e}") + + # fp32 precision over long sequences: allow up to 1e-2 (mean is ~1e-5, fine for training) + assert max_err < 1e-2, f"Logic bug detected (too large even for fp32): {max_err}" + print("✓ PASSED: Triton scan within fp32 precision bounds") + + # Timing comparison + import time + runs = 20 + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(runs): + o_tri, _ = ttsm_triton_scan(x_scan, g) + torch.cuda.synchronize() + t_triton = (time.perf_counter() - t0) / runs * 1000 + + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(runs): + h_py = torch.zeros(B, D_flat, device=device) + o_py = torch.zeros(B, T, D_flat, device=device) + for t in range(T): + h_py = g[:, t, :].exp() * h_py + x_scan[:, t, :] + o_py[:, t, :] = h_py + torch.cuda.synchronize() + t_python = (time.perf_counter() - t0) / runs * 1000 + + print(f"\nTiming (B={B}, T={T}, D_flat={D_flat}):") + print(f" Triton: {t_triton:.2f}ms") + print(f" Python: {t_python:.2f}ms") + print(f" Speedup: {t_python/t_triton:.1f}x")