
Non-record: Mamba-3 Hybrid SSM + SP8192 + Legal TTT — 1.1473 bpb #1644

Open
mradassaad wants to merge 1 commit into openai:main from mradassaad:mamba3-sp8192-ttt-pr

Conversation

@mradassaad

Non-record: Mamba-3 Hybrid SSM + SP8192 + Legal TTT — 1.1473 bpb

val_bpb: 1.1473 | 15.93MB | 5,224 steps | 113ms/step | 8×H100

A Mamba-3 state-space model submission. Improves on our prior SSM entry (PR #1355, 1.1526 bpb) by 5.3 mBPB via SP8192 tokenizer, INT8 embed GPTQ, and chunk score-first TTT.

| Stage | val_bpb | Delta | Size |
|---|---|---|---|
| BF16 (5,224 steps) | 1.1390 | | 88MB |
| + AR GPTQ (INT8 embed / INT6 matrix) | ~1.1490 | +10.0 mBPB | 15.93MB |
| + Chunk TTT (SGD lr=0.010, 310 chunks) | 1.1473 | −1.7 mBPB | 15.93MB |

Architecture

7-layer Mamba-3 SISO hybrid: 5 SSM blocks + 2 FlashAttention layers at positions 2 and 5, dim=512, d_state=64, expand=2, headdim=64, chunk_size=64, mlp_mult=3, 25.2M params. SP8192 BPE tokenizer (trained from scratch on FineWeb — see data section below).

Mamba-3's SSD (Structured State-Space Duality) kernels are pure Triton — no CUDA C++ dependencies, no custom extensions to install. SSD crosses over FlashAttention-2 in throughput at seq_len≈2K and is ~2× faster at our training length of 4096. The SSM state carries forward across eval windows (unlike attention KV), enabling stateful evaluation.

The hybrid design places attention at layers 2 and 5 (U-net encoder/decoder positions) because pure SSM at our 25M scale loses ~15 mBPB vs hybrid — attention is essential for short-range precision at this model size.
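For reference, a minimal sketch of this layout as a config object. The dataclass and field names here (`Mamba3HybridConfig`, `attn_layers`, etc.) are illustrative assumptions, not the training script's actual arguments; the numbers mirror the paragraph above.

```python
from dataclasses import dataclass

@dataclass
class Mamba3HybridConfig:
    """Illustrative config mirroring the architecture above; names are hypothetical."""
    vocab_size: int = 8192          # SP8192 BPE tokenizer
    dim: int = 512
    n_layer: int = 7
    attn_layers: tuple = (2, 5)     # FlashAttention at the U-net encoder/decoder positions
    d_state: int = 64
    expand: int = 2
    headdim: int = 64
    chunk_size: int = 64
    mlp_mult: int = 3
    train_seq_len: int = 4096

def layer_type(cfg: Mamba3HybridConfig, idx: int) -> str:
    """5 SSM blocks + 2 attention blocks out of 7 layers."""
    return "attention" if idx in cfg.attn_layers else "mamba3_ssm"
```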

Mamba3Block refactoring: _pre_ssd / _post_ssd

We factored the Mamba-3 forward pass into _pre_ssd (in_proj, split, reshape, compute ADT/DT, norms) and _post_ssd (reshape, optional outproj_norm, out_proj). This separates the Python-level tensor manipulation from the core SSD kernel call, making the code more readable and enabling targeted profiling of each phase. The split also makes it easier to experiment with alternative SSD kernel implementations without touching the pre/post logic.
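A rough sketch of that factoring, assuming standard PyTorch modules. The projection widths, split sizes, and the `ssd_kernel` callable are placeholders; only the `_pre_ssd` / `_post_ssd` structure follows the description above.

```python
import torch
import torch.nn as nn

class Mamba3BlockSketch(nn.Module):
    """Illustrative only: shows the _pre_ssd / _post_ssd split, not the real block."""

    def __init__(self, dim=512, d_inner=1024, d_state=64):
        super().__init__()
        # 2*d_inner (z, x) + 2*d_state (B, C) + 16 (dt; width is a placeholder)
        self.in_proj = nn.Linear(dim, 2 * d_inner + 2 * d_state + 16, bias=False)
        self.out_proj = nn.Linear(d_inner, dim, bias=False)
        self.outproj_norm = nn.RMSNorm(d_inner)
        self.d_inner, self.d_state = d_inner, d_state

    def _pre_ssd(self, x):
        # Python-level tensor work before the Triton SSD kernel:
        # input projection, splitting into SSD operands, reshapes, dt/ADT computation.
        zxbcdt = self.in_proj(x)
        z, x_ssd, B, C, dt = torch.split(
            zxbcdt, [self.d_inner, self.d_inner, self.d_state, self.d_state, 16], dim=-1)
        return z, x_ssd, B, C, dt

    def _post_ssd(self, y):
        # Reshape / optional outproj_norm / output projection after the kernel.
        return self.out_proj(self.outproj_norm(y))

    def forward(self, x, ssd_kernel):
        z, x_ssd, B, C, dt = self._pre_ssd(x)
        y = ssd_kernel(x_ssd, B, C, dt, z)   # core Triton SSD call, profiled in isolation
        return self._post_ssd(y)
```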

Triton kernel analysis

Significant time went into understanding and profiling the Mamba-3 Triton kernels to find fusion opportunities or tuning headroom.

Per-kernel microbenchmark (1×H100, bsz=32, seq=4096, isolated Mamba3Layer):

| Kernel | Time | regs/thread | SMEM | Notes |
|---|---|---|---|---|
| mamba3_siso_fwd_kernel | 1322 µs | 255 | 82,496 B | stages=3 |
| mamba3_siso_bwd_kernel_dqkv | 1190 µs | 255 | 107,280 B | stages=2 |
| mamba3_siso_bwd_kernel_rotary_bias_angles | 588 µs | 255 | 4,096 B | atomic-add |
| mamba3_siso_bwd_kernel_dzdo | 455 µs | 32 | 0 B | produces dO_scaled + dZ |
| mamba3_siso_bwd_kernel_ddt_dtrap_dinput_states | 18 µs | 30 | 0 B | negligible |

Total fwd+bwd = 9.59 ms/iter. Triton kernels ≈ 3.57 ms; the remaining ~6 ms is in_proj/out_proj GEMMs, RMSNorm, SiLU, residuals.

Key insight: the bottleneck is SMEM-per-pipeline-stage, not register pressure. The dqkv backward kernel uses 107 KB × 2 stages = 214 KB, nearly saturating H100's 228 KB L1/SMEM budget. regs/thread=255 is the ptxas ceiling but not the binding constraint — spill lands in L1 (cheap). Any kernel fusion at these SMEM levels must be SMEM-neutral.
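The back-of-envelope check, using the numbers from the table above (the 228 KB combined L1/SMEM figure is the H100 per-SM budget):

```python
# SMEM budget check for the dqkv backward kernel on H100.
H100_SMEM_PER_SM = 228 * 1024          # combined L1/shared budget per SM, in bytes
smem_per_stage   = 107_280             # from the microbenchmark table above
num_stages       = 2

used = smem_per_stage * num_stages     # 214,560 B
print(f"{used} / {H100_SMEM_PER_SM} B -> {used / H100_SMEM_PER_SM:.0%} of the budget")
# Even a few extra KB per stage (e.g. the z tile from the attempted fusion below)
# either exceeds the budget or forces the autotuner into a worse schedule.
```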

dzdo→dqkv prologue fusion attempt: We implemented a fused variant that moves the dzdo epilogue (dO_scaled, dZ computation) into the dqkv kernel's prologue, eliminating one kernel launch and one HBM round-trip. The fusion was mathematically correct (rel_l2=0 on all 9 output gradients) but +1.56 ms slower (9.59 → 11.15 ms, a 16% regression). Root cause: the extra z tile costs +8 KB of SMEM, which under 2-stage pipelining broke the autotuner's optimal schedule. The variant is left env-gated (MAMBA3_FUSED_BWD=1), off by default.

Extended autotune grid: Expanded from 9 to 36 configs (adding maxnreg ∈ {None, 128, 192, 255} × num_warps ∈ {4, 8, 16}). Both fwd and dqkv picked identical winners to the stock 9-config grid. The upstream kernels are already at the Pareto front for our shape.
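One way such a grid expansion can be built; this is a hedged sketch, not the upstream kernels' actual autotune code. It assumes the stock `triton.Config` objects are available and that the installed Triton (3.x per the requirements) supports the `maxnreg` argument; block sizes and exact counts are left to the stock configs.

```python
import itertools
import triton

def expand_configs(stock_configs, maxnreg_options=(None, 128, 192, 255),
                   warp_options=(4, 8, 16)):
    """Cross the stock block-size configs with maxnreg and num_warps variants."""
    expanded = []
    for cfg, maxnreg, num_warps in itertools.product(stock_configs,
                                                     maxnreg_options, warp_options):
        expanded.append(
            triton.Config(dict(cfg.kwargs),
                          num_warps=num_warps,
                          num_stages=cfg.num_stages,
                          maxnreg=maxnreg))
    return expanded

# Usage (hypothetical): pass expand_configs(STOCK_CONFIGS) to
# @triton.autotune(configs=..., key=[...]) on the fwd / dqkv kernels.
```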

Mamba-3 tuning parameters (all tested, all negative)

| Parameter | Result | Why |
|---|---|---|
| rope_fraction=1.0 | No improvement, 1.6% slower | |
| headdim=128 | 145ms/step (vs 115ms baseline) | Worse tensor core utilization at our config |
| headdim=32 | 135ms/step | Too many kernel launches |
| chunk_size=128 | Slower than default 64 | |
| ngroups=2 at expand=2 | +25.2 mBPB post-quant | Adds ~500KB, forces destructive pruning |
| d_state=128 | 25%+ slower at any seq_len | |
| Pure Mamba 8K seq | −10.7 mBPB vs hybrid 4K | But no attention = worse overall |
| Pure Mamba 16K seq | Only +2.5 mBPB over 8K | Sub-linear scaling, not free |
| MIMO at our scale | Negligible per Mamba-3 paper Table 3 | Only meaningful at 1.5B+ |

Quantization: GPTQ with Embed Hessian

We ran 13 quantization configurations on a fixed BF16 checkpoint to isolate each component's contribution:

| Config | Gap (mBPB) | Size |
|---|---|---|
| INT6 all (no GPTQ) | +90.8 | 14.17MB |
| INT8 embed (no GPTQ) | +10.1 | 15.18MB |
| AR GPTQ + INT8 embed | +8.9 | 15.15MB |
| AR GPTQ + INT8 embed + embed Hessian | +9.8 | 15.21MB |
| Train-data GPTQ + INT8 embed + embed Hessian | +4.3 | 15.18MB |
| SDClip k=12.85 | +153.4 | 13.10MB |

Findings:

  1. INT8 embeddings = 90% of the fix (90.8 → 10.1 mBPB). The 8192×512 embedding is 16% of the model but extremely sensitive to INT6.
  2. GPTQ on matrices is negligible (0.3 mBPB difference).
  3. Embed Hessian via final_norm output hook closes the remaining gap for train-data GPTQ. The logit projection activation Hessian guides embedding quantization.
  4. SDClip is catastrophic for our architecture at any k value. SOTA's SDClip sigmas are tuned for different weight distributions.
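For finding 3, a rough sketch of how the embed Hessian can be collected from a final_norm output hook. The module name `final_norm` follows the text; the GPTQ plumbing, calibration loop, and `model.loss`-style interfaces are assumptions, not this PR's actual code.

```python
import torch

def collect_embed_hessian(model, calib_batches, dim=512):
    """Accumulate H = sum_x x x^T over final_norm outputs. These outputs are the inputs
    to the logit projection, so (per the PR) this activation Hessian is what guides
    GPTQ on the embedding table."""
    device = next(model.parameters()).device
    H = torch.zeros(dim, dim, device=device)
    count = 0

    def hook(_module, _inputs, output):
        nonlocal count
        x = output.detach().reshape(-1, output.shape[-1]).float()
        H.add_(x.T @ x)
        count += x.shape[0]

    handle = model.final_norm.register_forward_hook(hook)
    with torch.no_grad():
        for batch in calib_batches:      # e.g. AR-generated or train-data sequences
            model(batch)
    handle.remove()
    return H / max(count, 1)
```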

The torch.compile + Triton allocator bug

We discovered that torch.compile + triton.set_allocator(ContextVar.set) in Mamba-3's _Mamba3Function corrupts the CUDA driver state. After compiled training, any Triton kernel not cached by torch._inductor crashes with "illegal memory access" in load_binary.

This forced us to use AR self-generated GPTQ instead of the optimal train-data approach — AR generation uses the compiled model (kernels cached), avoiding fresh autotuning. Cost: +5.5 mBPB quant gap (9.8 vs 4.3).

We confirmed the corruption persists across torch._dynamo.reset(), triton.set_allocator(None), Triton cache deletion, different GPUs, and fresh Python subprocesses. It does NOT occur without prior torch.compile.

Test-Time Training

Chunk score-first TTT: Score 310 chunks of 32×4096 tokens under no_grad, then SGD adapt (lr=0.010, momentum=0.9, 1 epoch) on the same chunk. 76s eval time.
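A minimal sketch of the score-first loop, assuming a hypothetical `model.loss(chunk)` helper that returns the mean per-token loss; the real eval script's interfaces may differ.

```python
import torch

def chunk_score_first_ttt(model, chunks, lr=0.010, momentum=0.9):
    """Score-first TTT sketch: each chunk is scored with the current weights (so
    adaptation never leaks into its own score), then the model takes one SGD epoch
    on that same chunk before moving to the next one."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    total_nll, total_tokens = 0.0, 0

    for chunk in chunks:                       # e.g. 310 chunks of 32×4096 tokens
        with torch.no_grad():                  # score first (no_grad, not inference_mode;
            loss = model.loss(chunk)           # see the RoPE cache note below)
            total_nll += loss.item() * chunk.numel()
            total_tokens += chunk.numel()

        opt.zero_grad(set_to_none=True)        # then adapt: 1 SGD epoch on the same chunk
        model.loss(chunk).backward()
        opt.step()

    return total_nll / total_tokens
```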

TTT sweep (15 configs on a fixed checkpoint):

  • Constant lr=0.010 is optimal (−6.7 mBPB on a prior checkpoint)
  • Warmup-cosine: only 0.07–0.12 mBPB of additional gain — not worth the complexity
  • Window TTT (per-window adaptation on overlap tokens): −0.1 mBPB in 573s — dead. The gradient signal from 1×1024 tokens per window is too weak compared with chunk TTT's 32×4096.

RoPE cache bug (fixed): Scoring under torch.inference_mode() caches RoPE cos/sin as inference tensors; the subsequent TTT adaptation forward pass then hits those cached tensors and crashes. Fix: score under torch.no_grad() instead.

Stateful-Overlap Evaluation

Instead of sliding-window eval (~500s), we use stateful-overlap with 1024-token overlap (~32s). SSM state carries forward across windows; attention gets 1024 tokens of prior context to re-establish KV cache. Matches sliding-window within 0.3 mBPB and frees 468s of eval budget for GPTQ + TTT.
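One plausible arrangement of that loop, heavily hedged: the `model.score(inp, ssm_state=...)` interface is an assumption, and exactly how the carried SSM state interacts with the re-fed overlap prefix is glossed over here.

```python
import torch

@torch.no_grad()
def stateful_overlap_eval(model, windows, overlap=1024):
    """Stateful-overlap eval sketch: the SSM state from the previous window is fed
    forward, while attention rebuilds its KV context from an `overlap`-token prefix
    copied from the previous window. Only tokens after the prefix are scored."""
    state, prev_tail = None, None
    nll, scored = 0.0, 0
    for win in windows:                               # e.g. 4096-token eval windows
        inp = win if prev_tail is None else torch.cat([prev_tail, win])
        loss_per_tok, state = model.score(inp, ssm_state=state)
        new = loss_per_tok[-win.numel():]             # drop the overlap prefix from scoring
        nll += new.sum().item()
        scored += new.numel()
        prev_tail = win[-overlap:]
    return nll / scored
```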

An earlier diagnosis that stateful eval accumulates INT6 quant error in the SSM state was wrong — measurement showed quant delta is flat ~8.2 mBPB across 100-1892 windows. The real cause of pure-stateful BF16 regression was attention context loss at window boundaries, which the overlap fixes.

What didn't work

TBPTT (Truncated BPTT with Persistent SSM State)

The longest exploration dead-end: 10+ runs over several days. The idea: maintain SSM state across training segments so the model learns to use long-range recurrent memory.

Why it failed: The model learns content-specific state from training document streams that can't be recovered at eval on unseen documents. Training loss (1.8895) vs eval loss (1.9667) gap = 77 mNats of unrecoverable state advantage. AR self-gen warmup (0-131K tokens) didn't help — the model's self-generated text doesn't produce the content-specific state. BF16 stateful-overlap: 1.1648 (+17.4 mBPB vs baseline). Quant gap 5× worse (13.1 vs 2.7 mBPB).

Residual-Stream SLOT

The idea: optimize the initial SSM state for each eval window using gradients from the already-scored overlap tokens. Result: consistently +2–8 mBPB worse. There is no consistent gradient direction in general-domain FineWeb.

Other negative results

| Tried | Effect | Why |
|---|---|---|
| WD=0.085 (matching SOTA) | Plateaus at train_loss=6.5 | Our 26M/7L can't absorb that much regularization |
| 10 layers at expand=1.5 | 135-150ms/step | Too slow |
| Depth recurrence (loop layers) | BF16 works (+21 mBPB) | Post-quant +60 mBPB, no pruning headroom |
| FP16 in_proj rows | Only 3 mBPB, costs 400KB | Removed |
| INT8 embed + Brotli-11 | +1MB vs INT6 + LZMA | Net larger |

SP8192 Data

The SP8192 tokenizer shipped in the kevclark/parameter-golf HF repo differs from the one used by SOTA submissions (tok/byte 0.2835 vs 0.2683), causing ~64 mBPB of inflation. All top entrants regenerate it locally. We train the SentencePiece BPE tokenizer from scratch on docs_selected.jsonl; our tokenizer reaches tok/byte = 0.2634.
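A minimal sketch of regenerating the tokenizer locally; the actual flags live in data/download_hf_docs_and_tokenize.py and tokenizer_specs_8192.json, and the input path and character_coverage below are assumptions.

```python
import sentencepiece as spm

# Train an 8192-symbol BPE model (sketch; real settings come from tokenizer_specs_8192.json).
spm.SentencePieceTrainer.train(
    input="docs_selected.txt",        # raw text extracted from docs_selected.jsonl (assumed path)
    model_prefix="sp8192",
    model_type="bpe",
    vocab_size=8192,
    character_coverage=0.9995,
)

sp = spm.SentencePieceProcessor(model_file="sp8192.model")
ids = sp.encode("hello world")        # sanity check: tok/byte should land near 0.26
```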

Eval Time Budget (1×H100, after DDP teardown)

| Phase | Time |
|---|---|
| AR GPTQ (32 seqs × 4096 tokens + Hessians) | 223s |
| Chunk TTT (310 chunks, SGD) | 76s |
| Stateful-overlap eval | 32s |
| Total | 331s of 600s |

Training

  • Optimizer: Muon with MuonEq-R. WD=0.04, momentum=0.99, matrix_lr=0.025.
  • Schedule: Linear warmdown starting at iter 2600 (50% of training at full LR). Late QAT activates at lr_mul < 0.15.
  • Batch: 1M tokens per step, seq_len=4096.
  • EMA: decay=0.997, applied before GPTQ.
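A minimal sketch of the schedule described above; the actual implementation lives in train_mamba3_hybrid.py and may differ in detail.

```python
def lr_multiplier(step, total_steps=5224, warmdown_start=2600):
    """Linear warmdown: full LR until iter 2600 (~50% of training), then decay to 0."""
    if step < warmdown_start:
        return 1.0
    return max(0.0, 1.0 - (step - warmdown_start) / (total_steps - warmdown_start))

def late_qat_active(step, threshold=0.15):
    """Late QAT switches on once the schedule has decayed below the threshold from the PR."""
    return lr_multiplier(step) < threshold
```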

Setup & Run

```bash
# Install Mamba-3 (pure Triton, no CUDA build)
bash setup_mamba3.sh

# Generate SP8192 data from scratch (~35 min)
cd data && python3 download_hf_docs_and_tokenize.py \
  --output-root . --tokenizer-config tokenizer_specs_8192.json --skip-byte

# Run
VOCAB_SIZE=8192 NUM_LAYERS=7 NUM_ATTN_LAYERS=2 USE_BIGRAM_HASH=0 TRAIN_SEQ_LEN=4096 \
WARMDOWN_ITERS=2600 WARMDOWN_SHAPE=linear MUON_EQ_R=1 \
LATE_QAT_THRESHOLD=0.15 USE_GPTQ=1 QUANT_BITS=6 QUANT_BITS_EMBED=8 GPTQ_NUM_SEQS=32 \
EVAL_OVERLAP=1024 USE_LZMA=1 EVAL_TEMP=0.9 \
WEIGHT_DECAY=0.04 MUON_MOMENTUM=0.99 MATRIX_LR=0.025 \
torchrun --nproc_per_node=8 train_mamba3_hybrid.py
```

Requirements

  • PyTorch 2.9.1+cu128
  • Triton 3.5.1
  • mamba-ssm 2.3.1 (installed via setup_mamba3.sh)
  • SentencePiece, NumPy, LZMA (stdlib)
  • 8×H100 80GB SXM

Credits

| Component | Origin | Author |
|---|---|---|
| Mamba-3 SISO kernels | state-spaces/mamba | @tridao, @albertgu |
| SP8192 vocabulary + GPTQ on embeddings | PR #1394 | @clarkkev |
| MuonEq-R | PR #1217 | @bigbig |
| U-Net skip connections | PR #289, PR #1089 | @integrate-your-mind, @mikeapedia |
| Muon optimizer | PR #399 | @abaybektursun |
| Full Hessian GPTQ | PR #535, PR #1060 | @raahilshah, @dexhunter |
| Sliding window eval | PR #122 | @mtybadger |
| Warmdown schedule | PR #364 | @shikhar1729 |
| EMA | PR #315, PR #401 | @jfprincz, @newjordan |
| LeakyReLU² activation | PR #185 | @dttdrv |
| Logit softcap | PR #315 | @jfprincz |
| Late QAT | Competition folk knowledge | Multiple contributors |
| Score-first TTT | PR #549, PR #756 | @abaybektursun |

mrbese pushed a commit to mrbese/parameter-golf that referenced this pull request Apr 16, 2026
Major architecture overhaul informed by PR openai#1644 and PR openai#1355 findings:
- ngroups: 16→1 (shared B/C, fixes kernel crash, saves ~6.88M params)
- Attention: 1→2 layers at positions [2,5] (matches best Mamba-3 layout)
- d_state: 64→128 (better state capacity, only +458K params with ngroups=1)
- Depth recurrence: disabled (hurts SSMs by -69 mBPB per PR openai#1355)
- Triton cache: per-rank TRITON_CACHE_DIR prevents multi-GPU JIT race
- .contiguous() on all kernel inputs for safety
- Replaced einops rearrange with .reshape() where possible

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
mrbese pushed a commit to mrbese/parameter-golf that referenced this pull request Apr 16, 2026
Width > depth for SSMs (PR openai#1644 got best BPB with 7 layers).
~19.7M params, ~14.8 MB artifact (safe under 16 MB).
v3 was 15.15M / 7.56 MB — lots of headroom to grow.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
mradassaad added a commit to mradassaad/parameter-golf that referenced this pull request Apr 18, 2026
The mixed-precision commit (82991d6) accidentally hardcoded scale.clamp_min
to 1/127 (INT8 floor) for ALL rows — including INT6 rows that should floor
at 1/31. Consequence: INT6 rows got scales ~4× smaller than intended, and
values that should round to ±1 rounded to ±4-5 instead. This defeats the
selective ±1 pruning pass: today's P1 run had only 1.02M ±1 candidates vs
PR openai#1644's 3.32M, causing ~1.4 MiB of unnecessary compressed-size inflation.

All runs since 82991d6 have been affected. Fix: per-row scale floor matching
the per-row bit width (1/qmax for scalar qmax, 1/qmax[i] element-wise for
per-row Tensor qmax).

This bug explains why today's 18+ MiB runs couldn't fit — the architecture
and quant choices weren't the problem. Expected effect: P1 shrinks from
~17.4 MB to ~15.9 MB (legal), with equivalent or slightly better post-quant
quality (small values round to 0 rather than ±1).