Add opt-in bf16 GDN state for Qwen3-Next #2
Merged
Default fp32 unchanged. Set ``MLX_LM_GDN_STATE_BF16=1`` to allocate the gated-delta recurrent state as bf16, halving per-step state traffic (~90 MB/step at Qwen3-Next 32k N=3).

The Metal kernel's recurrence accumulator is a local ``float``, so only heap storage changes; ``StT`` is already wired through the kernel template and ``state_out`` allocation, so no kernel change is required.

Caches populated by an earlier run (or seeded externally) with fp32 state are coerced to bf16 on first use when the env var is set, so existing prompt-cache files keep working.

Validation (Qwen3-Next-80B-A3B-4bit, teacher-forced 96 tokens): KL <= 0.0068 (gate 0.01); top-1 match 95-96/96.

Tests:
- test_gated_delta_state_dtype_default_fp32: fresh state stays fp32
- test_gated_delta_state_dtype_bf16_opt_in: env var switches to bf16 and coerces fp32 seed state
- test_gated_delta_bf16_state_matches_fp32_state: 16-step rollout, cosine similarity of final state and final output >= 0.99
Summary
Adds an opt-in `MLX_LM_GDN_STATE_BF16=1` environment variable that allocates the Qwen3-Next gated-delta-network (GDN) recurrent state as bf16 instead of fp32. Default behavior is unchanged.

Motivation
GDN's per-layer recurrent state dominates decode-step bandwidth at long context. On M4 Max with Qwen3-Next-80B-A3B-4bit at 32k context, batch size 3:
`[B=3, Hv=32, Dv=128, Dk=128]` of state.

The Metal kernel's recurrence accumulator is a local `float` (fp32), so only heap-resident storage changes. The existing `StT` kernel template parameter and the `state_out` allocation already propagate the chosen dtype through `state_in → state_out`, so no kernel-side change is required.

Quality validation
Teacher-forced 96 tokens on Qwen3-Next-80B-A3B-4bit: KL <= 0.0068 (gate 0.01); top-1 match 95-96/96.
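For reference, per-token KL divergence and top-1 agreement between a reference (fp32-state) run and a test (bf16-state) run can be computed along the following lines. This is a minimal sketch in plain Python, not the evaluation script used for this PR: the function names are hypothetical, the inputs are assumed to be per-position logit lists from the two teacher-forced runs, and since it is not stated whether the reported KL is a maximum or a mean over positions, the sketch returns both.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one position's logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_divergence(ref_logits, test_logits):
    """KL(ref || test) in nats for a single token position."""
    log_p = log_softmax(ref_logits)
    log_q = log_softmax(test_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def argmax(values):
    """Index of the largest entry (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def compare_runs(ref_steps, test_steps):
    """Summarize agreement between two runs over the same forced tokens.

    ref_steps / test_steps: one list of logits per position (hypothetical
    input format; adapt to however the logits are actually collected).
    """
    kls = [kl_divergence(r, t) for r, t in zip(ref_steps, test_steps)]
    top1 = sum(argmax(r) == argmax(t) for r, t in zip(ref_steps, test_steps))
    return {
        "max_kl": max(kls),
        "mean_kl": sum(kls) / len(kls),
        "top1_matches": top1,
    }
```

A gate such as the 0.01 threshold would then be a comparison against `max_kl` or `mean_kl`, depending on which statistic the gate is defined on.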
Recurrence stability rationale: the per-step update is `state[i] = state[i] * decay + k[s] * delta`, where `decay = exp(-exp(A_log) * softplus(a + dt_bias))`. The kernel's internal fp32 accumulator performs the multiply before the bf16 round-trip on each step, so storage truncation contributes only at the bf16 quantization scale per step rather than compounding through the multiply.

API
Single environment variable, opt-in:
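A minimal sketch of the opt-in check follows, assuming that only the exact value `"1"` enables bf16 and that anything else keeps fp32. The function name `state_dtype_name` is hypothetical and returns plain strings so the snippet runs without MLX; the helper in this PR returns MLX dtypes and its body is not reproduced here.

```python
import os

ENV_VAR = "MLX_LM_GDN_STATE_BF16"


def state_dtype_name(environ=os.environ):
    """Name of the dtype to use for freshly allocated recurrent state.

    Assumption: the variable must be exactly "1" to opt in; unset or any
    other value keeps the fp32 default.
    """
    return "bfloat16" if environ.get(ENV_VAR) == "1" else "float32"


# Passing an explicit mapping makes the behavior easy to inspect
# without touching the real process environment.
print(state_dtype_name({}))
print(state_dtype_name({ENV_VAR: "1"}))
print(state_dtype_name({ENV_VAR: "0"}))
```

Reading the variable at call time rather than at import time means it can be set from the shell or from Python at any point before the first decode step.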
The env var is checked inside `gated_delta_update`, so it applies uniformly to all GDN consumers (`qwen3_next`, `qwen3_5`, `kimi_linear`) without per-model plumbing.

A config-field alternative was considered but rejected: a perf/bandwidth flag isn't a model-architectural choice, and threading a dtype through three model files' `__init__` chains adds surface area for what is fundamentally a memory trade-off.

Backwards compatibility
≠ "1"): fp32 state, byte-identical to prior behavior.

Implementation
Single file change: `mlx_lm/models/gated_delta.py`
- `_state_dtype()` helper returns `bf16` iff `MLX_LM_GDN_STATE_BF16=1`, else `float32`.
- `mx.zeros(..., dtype=mx.float32)` allocation sites (`gated_delta_ops`, `gated_delta_update`) now use `_state_dtype()`.
- `gated_delta_update` coerces a stale fp32 state on first step when the env var is set.

No callsite changes in `qwen3_next.py`: the existing kernel and the ops fallback both honor the supplied state dtype.

Tests
`tests/test_models.py` adds three small synthetic tests (no real model weights):
- `test_gated_delta_state_dtype_default_fp32`: env var unset → freshly allocated state is fp32 (historical behavior preserved).
- `test_gated_delta_state_dtype_bf16_opt_in`: env var set → freshly allocated state is bf16; an externally seeded fp32 state is coerced to bf16 on first use.
- `test_gated_delta_bf16_state_matches_fp32_state`: 16-step rollout with identical fp32 inputs against fp32 vs. bf16 recurrent state; asserts final-state and final-output cosine similarity ≥ 0.99.

Tests use `unittest.mock.patch.dict(os.environ, ...)` so each test cleanly restores environment state.
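The idea behind the rollout comparison can be illustrated without MLX. The sketch below is a toy under stated assumptions, not the test in this PR. It emulates bf16 storage by rounding float32 bit patterns, uses Python floats in place of the kernel's fp32 accumulator, and computes `delta` with the usual gated delta rule form `(v - state·k) * beta`. That form is an assumption, since only the state update line is given above. All function names, shapes, and input distributions are hypothetical.

```python
import math
import random
import struct


def to_bf16(x):
    """Round a float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def unit(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def cosine(a, b):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))


def rollout(store, steps=16, dv=8, dk=8, seed=0):
    """Run the recurrence; `store` models the dtype of the stored state."""
    rng = random.Random(seed)
    state = [[0.0] * dk for _ in range(dv)]
    out = [0.0] * dv
    for _ in range(steps):
        q = unit([rng.gauss(0, 1) for _ in range(dk)])
        k = unit([rng.gauss(0, 1) for _ in range(dk)])
        v = [rng.uniform(-1, 1) for _ in range(dv)]
        decay = rng.uniform(0.9, 1.0)
        beta = rng.uniform(0.1, 0.9)
        for i in range(dv):
            # Accumulate in full precision: decay first, then the delta update.
            row = [s * decay for s in state[i]]
            delta = (v[i] - sum(s * x for s, x in zip(row, k))) * beta
            row = [s + x * delta for s, x in zip(row, k)]
            out[i] = sum(s * x for s, x in zip(row, q))
            # Storage rounding happens once per step, after the arithmetic.
            state[i] = [store(s) for s in row]
    return [s for row in state for s in row], out


ref_state, ref_out = rollout(store=lambda x: x)
bf16_state, bf16_out = rollout(store=to_bf16)
print("state cosine:", cosine(ref_state, bf16_state))
print("output cosine:", cosine(ref_out, bf16_out))
```

Because both runs use the same seed, they see identical inputs and differ only in how the state is stored between steps, which is the property the third test isolates.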