Add opt-in bf16 GDN state for Qwen3-Next #2

Merged: benjamin-levin merged 1 commit into main from gdn-bf16-state on May 19, 2026

benjamin-levin (Owner) commented May 18, 2026

Summary

Adds an opt-in MLX_LM_GDN_STATE_BF16=1 environment variable that allocates the Qwen3-Next gated-delta-network (GDN) recurrent state as bf16 instead of fp32. Default behavior is unchanged.

Motivation

GDN's per-layer recurrent state dominates decode-step bandwidth at long context. On M4 Max with Qwen3-Next-80B-A3B-4bit at 32k context, batch size 3:

  • 30 GDN linear-attention layers, each carrying [B=3, Hv=32, Dv=128, Dk=128] of state.
  • fp32 storage: ~2 MB per (sequence, layer); ~60 MB/sequence × 3 = 180 MB read + written every decode step.
  • bf16 storage: ~90 MB/step, a 2× reduction in state traffic.
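The figures above can be reproduced with a few lines of arithmetic. This is only a back-of-the-envelope sketch: the function name `gdn_state_bytes` is made up for illustration, and the shape and layer count are taken directly from the bullets above.

```python
# Back-of-the-envelope GDN state size, using the figures quoted above.
# gdn_state_bytes is a hypothetical helper, not part of mlx-lm.
BYTES_PER_ELEM = {"float32": 4, "bfloat16": 2}
MIB = 1024 * 1024


def gdn_state_bytes(batch, layers, hv, dv, dk, dtype):
    """Total bytes of recurrent state across all GDN layers for one batch."""
    return batch * layers * hv * dv * dk * BYTES_PER_ELEM[dtype]


per_seq_layer = gdn_state_bytes(1, 1, 32, 128, 128, "float32")
fp32_total = gdn_state_bytes(3, 30, 32, 128, 128, "float32")
bf16_total = gdn_state_bytes(3, 30, 32, 128, 128, "bfloat16")

print(per_seq_layer / MIB)  # 2.0   (MB per sequence per layer, fp32)
print(fp32_total / MIB)     # 180.0 (MB of state touched per decode step, fp32)
print(bf16_total / MIB)     # 90.0  (MB of state touched per decode step, bf16)
```

The state shape has no sequence-length dimension, so the saving is a fixed number of bytes per decode step.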

The Metal kernel's recurrence accumulator is a local float (fp32), so only heap-resident storage changes. The existing StT kernel template parameter and the state_out allocation already propagate the chosen dtype through state_in → state_out, so no kernel-side change is required.

Quality validation

Teacher-forced 96 tokens on Qwen3-Next-80B-A3B-4bit:

  • KL ≤ 0.0068 (well under the 0.01 quality gate)
  • Top-1 match: 95-96/96 tokens

Recurrence stability rationale: the per-step update is state[i] = state[i] * decay + k[s] * delta, where decay = exp(-exp(A_log) * softplus(a + dt_bias)). The kernel does the multiply-add in its fp32 accumulator and rounds to bf16 only when writing the state back, so each step introduces one bf16 rounding per element at the bf16 quantization scale, rather than compounding rounding through the multiply.
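A rough pure-Python sketch of that argument follows. It is not the Metal kernel and makes several simplifying assumptions: Python floats stand in for the fp32 accumulator, bf16 storage is emulated by rounding the float32 bit pattern to its top 16 bits (ties to even), and `k[s] * delta` is replaced by an independent random input even though the real delta depends on the state. All names are illustrative.

```python
import math
import random
import struct


def round_to_bf16(x: float) -> float:
    """Emulate a bf16 store: round float32 bits to the top 16 bits, ties to even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias, ties go to the even value
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rollout(steps, store, size=64, seed=0):
    """Run state = state * decay + update, applying `store` on every write-back."""
    rng = random.Random(seed)
    state = [0.0] * size
    for _ in range(steps):
        decay = math.exp(-rng.uniform(0.01, 0.5))  # always in (0, 1)
        for i in range(size):
            # Wide accumulator: multiply-add happens before any rounding.
            acc = state[i] * decay + rng.uniform(-1.0, 1.0) * rng.uniform(-1.0, 1.0)
            state[i] = store(acc)  # rounding only at the storage boundary
    return state


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


reference = rollout(16, store=lambda v: v)   # full-precision storage
truncated = rollout(16, store=round_to_bf16)  # bf16-emulated storage
similarity = cosine(reference, truncated)
print(similarity)
```

Both rollouts draw the same random inputs because the seed is shared and the draws do not depend on the state, so any difference in the final state comes only from the storage rounding.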

API

Single environment variable, opt-in:

# default: fp32 state, unchanged from prior behavior
python -m mlx_lm.generate --model Qwen/Qwen3-Next-80B-A3B-4bit --prompt ...

# bf16 state, halves per-step state bandwidth
MLX_LM_GDN_STATE_BF16=1 python -m mlx_lm.generate --model Qwen/Qwen3-Next-80B-A3B-4bit --prompt ...

The env var is checked inside gated_delta_update, so it applies uniformly to all GDN consumers (qwen3_next, qwen3_5, kimi_linear) without per-model plumbing.

A config-field alternative was considered but rejected: a perf/bandwidth flag isn't a model-architectural choice, and threading a dtype through three model files' __init__ chains adds surface area for what is fundamentally a memory trade-off.

Backwards compatibility

  • Default (env var unset or ≠"1"): fp32 state, byte-identical to prior behavior.
  • A cache populated by a prior run (or seeded externally) with fp32 state is coerced to bf16 on first use when the env var is set, so existing prompt-cache files continue to work.

Implementation

Single file change: mlx_lm/models/gated_delta.py

  • New _state_dtype() helper returns bf16 iff MLX_LM_GDN_STATE_BF16=1, else float32.
  • The two mx.zeros(..., dtype=mx.float32) allocation sites (gated_delta_ops, gated_delta_update) now use _state_dtype().
  • gated_delta_update coerces a stale fp32 state on first step when the env var is set.

No callsite changes in qwen3_next.py — the existing kernel and the ops fallback both honor the supplied state dtype.

Tests

tests/test_models.py adds three small synthetic tests (no real model weights):

  • test_gated_delta_state_dtype_default_fp32 — env var unset → freshly allocated state is fp32 (historical behavior preserved).
  • test_gated_delta_state_dtype_bf16_opt_in — env var set → freshly allocated state is bf16; an externally seeded fp32 state is coerced to bf16 on first use.
  • test_gated_delta_bf16_state_matches_fp32_state — 16-step rollout with identical fp32 inputs against fp32 vs. bf16 recurrent state; asserts final-state and final-output cosine similarity ≥ 0.99.

Tests use unittest.mock.patch.dict(os.environ, ...) so each test cleanly restores environment state.
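For readers unfamiliar with the pattern, here is a minimal sketch of environment isolation with `patch.dict`. It does not reproduce the tests in tests/test_models.py: `bf16_enabled` and the test class are hypothetical, and only the env var name comes from this PR.

```python
import os
import unittest
from unittest.mock import patch

ENV_VAR = "MLX_LM_GDN_STATE_BF16"


def bf16_enabled() -> bool:
    """Hypothetical helper: true only when the env var is exactly "1"."""
    return os.environ.get(ENV_VAR) == "1"


class EnvIsolationTest(unittest.TestCase):
    def test_opt_in(self):
        before = os.environ.get(ENV_VAR)
        with patch.dict(os.environ, {ENV_VAR: "1"}):
            self.assertTrue(bf16_enabled())
        # patch.dict restores the original mapping on exit.
        self.assertEqual(os.environ.get(ENV_VAR), before)

    def test_default_when_unset(self):
        before = os.environ.get(ENV_VAR)
        with patch.dict(os.environ):
            os.environ.pop(ENV_VAR, None)  # removal is also undone on exit
            self.assertFalse(bf16_enabled())
        self.assertEqual(os.environ.get(ENV_VAR), before)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(EnvIsolationTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Because each test restores the variable on exit, the cases can run in any order without leaking the opt-in into the default-path test.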

Default fp32 unchanged. Set ``MLX_LM_GDN_STATE_BF16=1`` to allocate
the gated-delta recurrent state as bf16, halving per-step state
traffic (~90 MB/step at Qwen3-Next 32k N=3). The Metal kernel's
recurrence accumulator is a local ``float``, so only heap storage
changes; ``StT`` is already wired through the kernel template and
``state_out`` allocation, so no kernel change is required.

Caches populated by an earlier run (or seeded externally) with fp32
state are coerced to bf16 on first use when the env var is set, so
existing prompt-cache files keep working.

Validation (Qwen3-Next-80B-A3B-4bit, teacher-forced 96 tokens):
KL <= 0.0068 (gate 0.01); top-1 match 95-96/96.

Tests:
- test_gated_delta_state_dtype_default_fp32 — fresh state stays fp32
- test_gated_delta_state_dtype_bf16_opt_in — env var switches to bf16
  and coerces fp32 seed state
- test_gated_delta_bf16_state_matches_fp32_state — 16-step rollout,
  cosine similarity of final state and final output >= 0.99
@benjamin-levin benjamin-levin marked this pull request as ready for review May 19, 2026 18:24
@benjamin-levin benjamin-levin merged commit c1e1958 into main May 19, 2026
1 of 2 checks passed