Non-record: Mamba-3 Hybrid SSM + SP8192 + Legal TTT — 1.1473 bpb #1644
Open
mradassaad wants to merge 1 commit into openai:main from
Conversation
mrbese pushed a commit to mrbese/parameter-golf that referenced this pull request on Apr 16, 2026
Major architecture overhaul informed by PR openai#1644 and PR openai#1355 findings:
- ngroups: 16→1 (shared B/C, fixes kernel crash, saves ~6.88M params)
- Attention: 1→2 layers at positions [2,5] (matches best Mamba-3 layout)
- d_state: 64→128 (better state capacity, only +458K params with ngroups=1)
- Depth recurrence: disabled (hurts SSMs by -69 mBPB per PR openai#1355)
- Triton cache: per-rank TRITON_CACHE_DIR prevents multi-GPU JIT race
- .contiguous() on all kernel inputs for safety
- Replaced einops rearrange with .reshape() where possible

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
mrbese pushed a commit to mrbese/parameter-golf that referenced this pull request on Apr 16, 2026
Width > depth for SSMs (PR openai#1644 got best BPB with 7 layers). ~19.7M params, ~14.8 MB artifact (safe under 16 MB). v3 was 15.15M / 7.56 MB — lots of headroom to grow.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
mradassaad added a commit to mradassaad/parameter-golf that referenced this pull request on Apr 18, 2026
The mixed-precision commit (82991d6) accidentally hardcoded scale.clamp_min to 1/127 (the INT8 floor) for ALL rows — including INT6 rows that should floor at 1/31. Consequence: INT6 rows got scales ~4× smaller than intended, and values that should round to ±1 rounded to ±4-5 instead. This defeats the selective ±1 pruning pass: today's P1 run had only 1.02M ±1 candidates vs PR openai#1644's 3.32M, causing ~1.4 MiB of unnecessary compressed-size inflation. All runs since 82991d6 have been affected.

Fix: per-row scale floor matching the per-row bit width (1/qmax for scalar qmax, 1/qmax[i] element-wise for per-row Tensor qmax). This bug explains why today's 18+ MiB runs couldn't fit — the architecture and quant choices weren't the problem. Expected effect: P1 shrinks from ~17.4 MB to ~15.9 MB (legal), with equivalent or slightly better post-quant quality (small values round to 0 rather than ±1).
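A minimal sketch of the per-row scale floor the commit describes (function name and tensor layout are assumptions, not the repo's code):

```python
import torch

def quantize_rows(w: torch.Tensor, qmax: torch.Tensor):
    """Sketch of symmetric per-row quantization with a per-row scale floor.

    w:    (rows, cols) weight matrix
    qmax: (rows,) tensor of per-row integer maxima, e.g. 127 for INT8 rows
          and 31 for INT6 rows.
    """
    qmax = qmax.to(w.dtype)
    scale = w.abs().amax(dim=1) / qmax
    # The fix: floor each row's scale at 1/qmax for that row's bit width,
    # instead of hardcoding 1/127 (the INT8 floor) for every row.
    scale = torch.maximum(scale, 1.0 / qmax)
    q = torch.round(w / scale[:, None]).clamp(-qmax[:, None], qmax[:, None])
    return q.to(torch.int8), scale
```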
Non-record: Mamba-3 Hybrid SSM + SP8192 + Legal TTT — 1.1473 bpb
val_bpb: 1.1473 | 15.93MB | 5,224 steps | 113ms/step | 8×H100
A Mamba-3 state-space model submission. Improves on our prior SSM entry (PR #1355, 1.1526 bpb) by 5.3 mBPB via SP8192 tokenizer, INT8 embed GPTQ, and chunk score-first TTT.
Architecture
7-layer Mamba-3 SISO hybrid: 5 SSM blocks + 2 FlashAttention layers at positions 2 and 5, dim=512, d_state=64, expand=2, headdim=64, chunk_size=64, mlp_mult=3, 25.2M params. SP8192 BPE tokenizer (trained from scratch on FineWeb — see data section below).
Mamba-3's SSD (Structured State-Space Duality) kernels are pure Triton — no CUDA C++ dependencies, no custom extensions to install. SSD crosses over FlashAttention-2 in throughput at seq_len≈2K and is ~2× faster at our training length of 4096. The SSM state carries forward across eval windows (unlike attention KV), enabling stateful evaluation.
The hybrid design places attention at layers 2 and 5 (U-net encoder/decoder positions) because pure SSM at our 25M scale loses ~15 mBPB vs hybrid — attention is essential for short-range precision at this model size.
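A minimal sketch of the resulting layer plan (purely illustrative; block classes are omitted):

```python
# Hybrid layout: 7 layers, FlashAttention at positions 2 and 5, SSM elsewhere.
N_LAYERS, ATTN_POSITIONS = 7, {2, 5}

layer_plan = ["attn" if i in ATTN_POSITIONS else "ssm" for i in range(N_LAYERS)]
assert layer_plan == ["ssm", "ssm", "attn", "ssm", "ssm", "attn", "ssm"]
assert layer_plan.count("ssm") == 5 and layer_plan.count("attn") == 2
```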
Mamba3Block refactoring: _pre_ssd / _post_ssd
We factored the Mamba-3 forward pass into _pre_ssd (in_proj, split, reshape, compute ADT/DT, norms) and _post_ssd (reshape, optional outproj_norm, out_proj). This separates the Python-level tensor manipulation from the core SSD kernel call, making the code more readable and enabling targeted profiling of each phase. The split also makes it easier to experiment with alternative SSD kernel implementations without touching the pre/post logic.
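A skeleton of the split, with placeholder internals (the real _pre_ssd also computes the ADT/DT state-space terms, and the real kernel is the Triton SSD rather than the simple gate used here):

```python
import torch
import torch.nn as nn

class Mamba3BlockSketch(nn.Module):
    """Shape of the refactor: _pre_ssd prep, SSD kernel call, _post_ssd projection."""

    def __init__(self, dim: int = 512, expand: int = 2):
        super().__init__()
        d_inner = dim * expand
        self.in_proj = nn.Linear(dim, 2 * d_inner)   # placeholder for the real fused in_proj
        self.norm = nn.RMSNorm(d_inner)              # PyTorch >= 2.4
        self.out_proj = nn.Linear(d_inner, dim)

    def _pre_ssd(self, x):
        # Python-level prep: projection, split, reshape, norms.
        xz = self.in_proj(x)
        x_inner, z = xz.chunk(2, dim=-1)
        return self.norm(x_inner), z

    def _ssd(self, x_inner, z):
        # Stand-in for the Triton SSD kernel call (a gate, not the real scan).
        return x_inner * torch.nn.functional.silu(z)

    def _post_ssd(self, y):
        # Reshape / optional outproj_norm / out_proj.
        return self.out_proj(y)

    def forward(self, x):
        x_inner, z = self._pre_ssd(x)
        return self._post_ssd(self._ssd(x_inner, z))
```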
Triton kernel analysis
Significant time went into understanding and profiling the Mamba-3 Triton kernels to find fusion opportunities or tuning headroom.
Per-kernel microbenchmark (1×H100, bsz=32, seq=4096, isolated Mamba3Layer):
Total fwd+bwd = 9.59 ms/iter. Triton kernels ≈ 3.57 ms; the remaining ~6 ms is in_proj/out_proj GEMMs, RMSNorm, SiLU, residuals.
Key insight: the bottleneck is SMEM-per-pipeline-stage, not register pressure. The dqkv backward kernel uses 107 KB × 2 stages = 214 KB, nearly saturating H100's 228 KB L1/SMEM budget.
regs/thread=255 is the ptxas ceiling but not the binding constraint — spill lands in L1 (cheap). Any kernel fusion at these SMEM levels must be SMEM-neutral.
dzdo→dqkv prologue fusion attempt: Implemented a fused variant that moves the dzdo epilogue (dO_scaled, dZ computation) into the dqkv kernel's prologue, eliminating one kernel launch and the HBM round-trip. The fusion was mathematically correct (rel_l2=0 on all 9 output gradients) but +1.56 ms slower (9.59 → 11.15 ms, a 16% regression). Root cause: the +8 KB of SMEM for the extra z tile at stage=2 pipelining broke the autotuner's optimal schedule. Left env-gated (MAMBA3_FUSED_BWD=1), off by default.
Extended autotune grid: Expanded from 9 to 36 configs (adding maxnreg ∈ {None, 128, 192, 255} × num_warps ∈ {4, 8, 16}). Both fwd and dqkv picked the same winners as the stock 9-config grid. The upstream kernels are already at the Pareto front for our shape.
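A sketch of how such an expanded grid can be expressed with Triton's autotuner (tile shapes and the kernel are placeholders; the maxnreg field requires a recent Triton release):

```python
import itertools
import triton

# Expanded autotune grid: cross maxnreg and num_warps over a base set of tile
# shapes (values below are illustrative, not the upstream kernel's grid).
BASE_TILES = [
    {"BLOCK_M": 64, "BLOCK_N": 64},
    {"BLOCK_M": 128, "BLOCK_N": 64},
    {"BLOCK_M": 64, "BLOCK_N": 128},
]

CONFIGS = [
    triton.Config(tile, num_warps=w, num_stages=2, maxnreg=r)
    for tile, w, r in itertools.product(BASE_TILES, (4, 8, 16), (None, 128, 192, 255))
]  # 3 x 3 x 4 = 36 configs

# Attached to a kernel in the usual way:
# @triton.autotune(configs=CONFIGS, key=["seqlen", "headdim"])
# @triton.jit
# def dqkv_bwd_kernel(...): ...
```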
Mamba-3 tuning parameters (all tested, all negative)
Quantization: GPTQ with Embed Hessian
We ran 13 quantization configurations on a fixed BF16 checkpoint to isolate each component's contribution:
Findings:
The torch.compile + Triton allocator bug
We discovered that torch.compile + triton.set_allocator(ContextVar.set) in Mamba-3's _Mamba3Function corrupts the CUDA driver state. After compiled training, any Triton kernel not cached by torch._inductor crashes with "illegal memory access" in load_binary.
This forced us to use AR self-generated GPTQ instead of the optimal train-data approach — AR generation uses the compiled model (kernels cached), avoiding fresh autotuning. Cost: +5.5 mBPB quant gap (9.8 vs 4.3).
We confirmed the corruption persists across torch._dynamo.reset(), triton.set_allocator(None), Triton cache deletion, different GPUs, and fresh Python subprocesses. It does NOT occur without prior torch.compile.
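For context on the GPTQ step itself: it only needs a Hessian of layer inputs accumulated from calibration activations, which is why the calibration text can come either from train data or from AR self-generation. A minimal sketch of that accumulation (standard GPTQ bookkeeping; names and shapes are assumptions):

```python
import torch

@torch.no_grad()
def accumulate_hessian(layer_inputs, hidden_dim, device="cuda"):
    """Accumulate the GPTQ Hessian H ~ 2/N * sum_i x_i x_i^T from calibration activations.

    layer_inputs: iterable of (tokens, hidden_dim) activation batches captured at the
    layer being quantized (here generated by the compiled model via AR self-generation,
    so no fresh Triton autotuning is triggered).
    """
    H = torch.zeros(hidden_dim, hidden_dim, device=device, dtype=torch.float32)
    n = 0
    for x in layer_inputs:
        x = x.to(device=device, dtype=torch.float32).reshape(-1, hidden_dim)
        H += 2.0 * (x.T @ x)
        n += x.shape[0]
    return H / max(n, 1)
```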
Test-Time Training
Chunk score-first TTT: Score 310 chunks of 32×4096 tokens under no_grad, then SGD adapt (lr=0.010, momentum=0.9, 1 epoch) on the same chunk. 76s eval time.
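A minimal sketch of the score-first loop, assuming an HF-style model output with a .loss field (the key points are scoring each chunk before adapting on it, and using torch.no_grad() rather than inference_mode; see the RoPE note below):

```python
import torch

def chunk_score_first_ttt(model, chunks, lr=0.010, momentum=0.9):
    """Score each chunk before adapting on it, so the reported loss never sees
    weights updated on that chunk (chunks are e.g. 310 batches of 32x4096 tokens)."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    total_nll, total_tokens = 0.0, 0
    for tokens in chunks:
        # 1) Score first, under no_grad (not inference_mode; RoPE cache bug below).
        with torch.no_grad():
            total_nll += model(tokens).loss.item() * tokens.numel()
            total_tokens += tokens.numel()
        # 2) Then SGD-adapt on the same chunk for one epoch.
        opt.zero_grad(set_to_none=True)
        model(tokens).loss.backward()
        opt.step()
    # Mean nats/token; multiply by tok/byte and divide by ln 2 to get bpb.
    return total_nll / total_tokens
```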
TTT sweep (15 configs on a fixed checkpoint):
RoPE cache bug (fixed): Scoring under torch.inference_mode() caches RoPE cos/sin as inference tensors. The TTT adaptation forward then hits the cache → crash. Fix: use torch.no_grad() instead.
Stateful-Overlap Evaluation
Instead of sliding-window eval (~500s), we use stateful-overlap with 1024-token overlap (~32s). SSM state carries forward across windows; attention gets 1024 tokens of prior context to re-establish KV cache. Matches sliding-window within 0.3 mBPB and frees 468s of eval budget for GPTQ + TTT.
An earlier diagnosis that stateful eval accumulates INT6 quant error in the SSM state was wrong — measurement showed quant delta is flat ~8.2 mBPB across 100-1892 windows. The real cause of pure-stateful BF16 regression was attention context loss at window boundaries, which the overlap fixes.
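A minimal sketch of the overlap scheme, assuming a hypothetical model.score interface that takes and returns SSM state (the state_at argument and per-token NLL output are illustrative, not the repo's API):

```python
import torch

@torch.no_grad()
def stateful_overlap_eval(model, tokens, window=4096, overlap=1024):
    """Score a long token stream window by window, carrying SSM state forward and
    re-feeding the last `overlap` tokens so attention layers rebuild local context."""
    step = window - overlap
    state = None
    total_nll, n_scored = 0.0, 0
    for start in range(0, len(tokens), step):
        chunk = tokens[start : start + window]
        # Hypothetical interface: per-token NLL for the chunk, plus the SSM state
        # snapshotted `step` tokens in, so the carried state lines up with the start
        # of the next window and never covers the overlap region twice.
        nll, state = model.score(chunk, ssm_state=state, state_at=step)
        counted = nll if start == 0 else nll[overlap:]   # overlap tokens were scored last window
        total_nll += counted.sum().item()
        n_scored += counted.numel()
    return total_nll / n_scored   # mean nats/token
```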
What didn't work
TBPTT (Truncated BPTT with Persistent SSM State)
The longest exploration dead-end: 10+ runs over several days. The idea: maintain SSM state across training segments so the model learns to use long-range recurrent memory.
Why it failed: The model learns content-specific state from training document streams that can't be recovered at eval on unseen documents. Training loss (1.8895) vs eval loss (1.9667) gap = 77 mNats of unrecoverable state advantage. AR self-gen warmup (0-131K tokens) didn't help — the model's self-generated text doesn't produce the content-specific state. BF16 stateful-overlap: 1.1648 (+17.4 mBPB vs baseline). Quant gap 5× worse (13.1 vs 2.7 mBPB).
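A sketch of the TBPTT loop that was abandoned, assuming a stateful forward interface (the ssm_state keyword and .ssm_state output are placeholders, not the repo's API):

```python
import torch

def tbptt_train(model, optimizer, segment_stream):
    """Truncated BPTT with persistent SSM state: the state from one training segment
    is detached and fed into the next, so the model can learn to rely on long-range
    recurrent memory. This is the variant that did NOT transfer to unseen documents."""
    state = None
    for segment in segment_stream:                 # consecutive segments of one document stream
        out = model(segment, ssm_state=state)      # hypothetical stateful forward
        out.loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        # Truncate the gradient graph at segment boundaries but keep the state values.
        state = tuple(s.detach() for s in out.ssm_state)
    return state
```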
Residual-Stream SLOT
Optimize initial SSM state per eval window using gradients from already-scored overlap tokens. Result: +2–8 mBPB consistently. No consistent gradient direction in general-domain FineWeb.
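A sketch of the per-window state optimization (again a hypothetical stateful interface; only the initial state is a trainable leaf, and only the already-scored overlap tokens supply the gradient, so no unscored text leaks into the adaptation):

```python
import torch

def slot_init_state(model, window_tokens, init_state, overlap=1024, steps=3, lr=1e-2):
    """Residual-Stream SLOT sketch: tune a window's initial SSM state on the
    already-scored overlap tokens before scoring the rest of the window.
    Only `state` is in the optimizer; model parameters are left untouched."""
    state = tuple(s.clone().requires_grad_(True) for s in init_state)
    opt = torch.optim.SGD(state, lr=lr)
    for _ in range(steps):
        out = model(window_tokens[:overlap], ssm_state=state)   # hypothetical interface
        out.loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
    return tuple(s.detach() for s in state)
```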
Other negative results
SP8192 Data
The SP8192 tokenizer shipped in the kevclark/parameter-golf HF repo differs from what SOTA submissions use (tok/byte 0.2835 vs 0.2683), causing ~64 mBPB inflation. All top entrants regenerate it locally. We train our SentencePiece BPE tokenizer from scratch on docs_selected.jsonl. Our tokenizer: tok/byte = 0.2634.
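A minimal sketch of the from-scratch tokenizer training (the SentencePiece flags and the "text" JSONL field are assumptions; only the 8192-token BPE vocab and the docs_selected.jsonl source come from this submission):

```python
import json
import sentencepiece as spm

# SentencePiece wants plain text, so first dump each record's text field.
# The "text" key is an assumption about docs_selected.jsonl's schema.
with open("docs_selected.jsonl") as f_in, open("tok_corpus.txt", "w") as f_out:
    for line in f_in:
        f_out.write(json.loads(line)["text"].replace("\n", " ") + "\n")

spm.SentencePieceTrainer.train(
    input="tok_corpus.txt",
    model_prefix="sp8192",
    vocab_size=8192,
    model_type="bpe",
    byte_fallback=True,        # assumption: avoid <unk> on rare bytes
    character_coverage=1.0,
)
```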
Eval Time Budget (1×H100, after DDP teardown)
Training
Setup & Run
Requirements
mamba-ssm 2.3.1 (installed via setup_mamba3.sh)
Credits