fast: add fused_qsdpa primitive (decode-shape QSDPA on quantized KV)#1
Merged
Conversation
Co-authored-by: Cheng <[email protected]>
78493dd to
253ce12
Compare
Adds mx.fast.fused_qsdpa, a compiled-in 2-pass scaled-dot-product
attention primitive that operates directly on a quantized KV cache
without materializing dequantized K/V tensors. Targets decode-shape
(T_q == 1), head_dim 256, gqa_factor 8, bits in {4, 8}, group_size
in {32, 64}.
Structural advantage over a user-space mx.fast.metal_kernel
prototype: each threadgroup's GQA_FACTOR simdgroups cooperatively
dequantize a K/V tile into threadgroup memory, then all simdgroups
scan all rows. K/V dequant cost scales with T_kv, not
T_kv * GQA_FACTOR.
Microbench at the target shape (D=256, gqa_factor=8, B=1, T_q=1,
bits=4, group_size=64) vs stock bf16 scaled_dot_product_attention
on dequantized K/V:
T_kv 4k : 1.03x
T_kv 16k : 1.11x
T_kv 32k : 1.18x
T_kv 64k : 1.39x
T_kv 96k : 1.25x
End-to-end (Qwen3.6-35B-A3B-4bit on M4 Max 36GB) with KV-quant
4-bit group 64:
32k / 1 : 78.0 -> 92.3 tok/s (+18%)
96k / 1 : 53.7 -> 68.4 tok/s (+27%)
left_padding extension. Optional left_padding: array argument
accepts an int32[B] array of per-batch leading-pad counts. The
kernel masks positions before left_padding[batch_idx] via one
extra comparison per K position. Required for KV-quantization on
BatchKVCache instances with non-zero left_padding (heterogeneous-
prompt multi-agent serving). Heterogeneous N=3 at 32k context
goes from 125.4 -> 153.9 tok/s (+22.8%), closing the
heterogeneous-vs-homogeneous gap from -20.7% to -1.1%.
A pre-existing edge case in the online-softmax update is fixed
here: when both max_score and new_max are -INFINITY (the strided
kpos pattern of the first tile lands entirely inside the padded
prefix for some batch row), exp(max_score - new_max) evaluates to
exp(NaN) = NaN. The patched update guards on
new_max > -INFINITY. The threadgroup barrier still fires
unconditionally to preserve cross-simdgroup synchronization. The
unmasked path never reached this state - -INFINITY only appeared
for out-of-bounds kpos - but adding left_padding made it
reachable.
Linux CPU build. eval_gpu lives in Metal backend only. The vtable
is anchored by FusedQSDPA::is_equivalent in mlx/fast.cpp (always
compiled) and a NO_GPU_MULTI(FusedQSDPA) shim in
mlx/backend/no_gpu/primitives.cpp - matches the pattern used by
ScaledDotProductAttention, RoPE, ConvertFP8, etc.
Correctness. Across (T_kv, bits, group_size) configurations:
cosine >= 0.9999 vs bf16-on-dequantized reference, max-abs <=
5e-4. left_padding configurations: cosine >= 0.999985, max-abs <=
7e-4. Greedy decode tokens match bf16-on-dequantized path
token-for-token over multi-thousand-token generations.
Tests (python/tests/test_fast.py):
test_fused_qsdpa_correctness - cosine >= 0.9999 vs bf16
SDPA on dequantized K/V
across {4k,16k} x {4,8}b
x {32,64} group sizes.
test_fused_qsdpa_left_padding - cosine >= 0.999985 across
{none, zeros, varied}
left_padding configs.
test_fused_qsdpa_extreme_left_padding- left_padding=[0,T_kv-1]
regression for the
online-softmax NaN edge.
test_fused_qsdpa_input_validation - bits/group_size/head_dim/
gqa_factor/mode/T_q/dim
/left_padding ValueErrors.
Files:
mlx/fast.h +36
mlx/fast.cpp +139
mlx/fast_primitives.h +75
mlx/backend/metal/kernels/CMakeLists.txt +1
mlx/backend/metal/kernels/fused_qsdpa.h +366 (new)
mlx/backend/metal/kernels/fused_qsdpa.metal +29 (new)
mlx/backend/metal/quantized.cpp +206
mlx/backend/no_gpu/primitives.cpp +1
python/src/fast.cpp +76
python/tests/test_fast.py +307
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
253ce12 to
202d4ce
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds
mx.fast.fused_qsdpa, a compiled-in 2-pass scaled-dot-productattention primitive that operates directly on a quantized KV cache without
materializing dequantized K/V tensors. Targets decode-shape only:
T_q == 1head_dim == 256gqa_factor == 8bits in {4, 8},group_size in {32, 64},mode == "affine"left_paddingOut-of-distribution shapes raise
ValueErrorso the caller can fall backto a Python-side reference path.
Motivation
When the KV cache is held in affine-quantized form, the stock decode path
needs three GPU launches per attention call: dequantize K, dequantize V,
then bf16 SDPA. On Qwen3.6-35B-A3B-4bit (M4 Max 36GB) with 4-bit / group-64
KV-quant this regresses end-to-end tok/s at long context:
The KV-bytes win is meant to enable larger contexts, but the dequantize-
then-SDPA path eats most of it.
fused_qsdpadoes the dequant inside theattention kernel and shares the dequantized tiles across the eight query
heads in each kv head's GQA group.
Design
Two-pass kernel modelled on
sdpa_vector_2pass:(kv_head, batch, kpos_block)per threadgroup.The threadgroup has
GQA_FACTORsimdgroups; each simdgroup owns onequery head. K and V tiles are cooperatively dequantized into threadgroup
memory once per kpos, then all
GQA_FACTORsimdgroups scan thattile to update their own
(max, sum_exp)online-softmax state and aper-block partial output.
the standard log-sum-exp combine.
Structural advantage over a user-space
mx.fast.metal_kernelprototype:K/V dequant cost scales with
T_kv, notT_kv * GQA_FACTOR. The 2-passsplit is what makes that sharing tractable — pass 1 only ever scans a
bounded number of kpos rows so the threadgroup memory stays sized to a
single tile, and the LSE reduction in pass 2 keeps the numerics exact.
Grid / threadgroup layout:
The seven Q/K/V buffer inputs and the optional
mask/left_paddinginputs follow the existing
ScaledDotProductAttentionconvention; theMetal dispatcher mirrors that ordering exactly.
left_paddingextensionA second motivation: heterogeneous-prompt multi-agent serving stores all
B prompts in a single
BatchKVCachetensor and uses a per-batchleft_padding[b]scalar to mark which positions are valid. Withoutexplicit support, KV-quantization can't be enabled for those workloads.
fused_qsdpaaccepts an optionalint32[B]left_paddingarray. Thekernel masks positions before
left_padding[batch_idx]via one extracomparison per K position. The branch is gated on a
function_constant(id 31) so the unmasked path pays no per-K cost.
Online-softmax NaN edge case (fixed here). When both
max_scoreandnew_maxare-INFINITY— the strided kpos pattern of the first tilelands entirely inside the padded prefix for some batch row —
exp(max_score - new_max)evaluates toexp(NaN) = NaN. The patchedupdate guards on
new_max > -INFINITY. The threadgroup barrier stillfires unconditionally to preserve cross-simdgroup synchronization. The
unmasked path never reached this state (
-INFINITYonly appeared forout-of-bounds
kpos) but addingleft_paddingmade it reachable; thenew
test_fused_qsdpa_extreme_left_paddingtest is a regression for it.Performance
All measurements on Apple M4 Max 36 GB, MLX fork built from this branch.
Microbench shape is the Qwen3.6 attention shape (
B=1, H=16, KV_H=2, T_q=1, head_dim=256, gqa_factor=8). Each cell reports the average over80 iterations after 5 warmup iterations, with
mx.synchronize()betweenphases. End-to-end heterogeneous-batching uses Qwen3.6-35B-A3B-4bit with
the canonical full-stack patches and
BatchGeneratorfrommlx_lm.Baselines:
mx.fast.scaled_dot_product_attentionon dequantized K/V.This is the fastest stock primitive — it pays the dequantize cost
outside the kernel and then runs Apple's optimized fp16 SDPA. Any
fused QSDPA path must beat this to be worth shipping.
mlx_lm.models.base.quantized_scaled_dot_product_attention,the dequant-then-SDPA composition that production currently hits.
Microbench — group_size = 64
| T_kv | bits | fork ms | bf16 ms | stock ms | vs bf16 | vs stock | max|d| | cosine |
|-------|------|---------|---------|----------|---------|----------|--------|--------|
| 1024 | 4 | 0.1755 | 0.1613 | 0.1945 | 0.92x | 1.11x | 0.0010 | 0.99999 |
| 1024 | 8 | 0.1583 | 0.1325 | 0.1782 | 0.84x | 1.13x | 0.0010 | 0.99999 |
| 2048 | 4 | 0.1457 | 0.1466 | 0.1918 | 1.01x | 1.32x | 0.0005 | 0.99999 |
| 2048 | 8 | 0.1425 | 0.1421 | 0.1964 | 1.00x | 1.38x | 0.0005 | 0.99999 |
| 4096 | 4 | 0.214 | 0.206 | 0.235 | 0.96x | 1.10x | 0.0005 | 0.99999 |
| 4096 | 8 | 0.189 | 0.177 | 0.226 | 0.94x | 1.20x | 0.0005 | 0.99999 |
| 16384 | 4 | 0.328 | 0.362 | 0.504 | 1.10x | 1.54x | 0.0002 | 0.99999 |
| 16384 | 8 | 0.255 | 0.291 | 0.460 | 1.14x | 1.80x | 0.0002 | 0.99999 |
| 32768 | 4 | 0.352 | 0.440 | 0.846 | 1.25x | 2.40x | 0.0002 | 0.99999 |
| 32768 | 8 | 0.371 | 0.397 | 0.812 | 1.07x | 2.19x | 0.0002 | 0.99999 |
| 65536 | 4 | 0.575 | 0.712 | 1.516 | 1.24x | 2.64x | 0.0001 | 0.99999 |
| 65536 | 8 | 0.626 | 0.734 | 1.523 | 1.17x | 2.43x | 0.0001 | 0.99999 |
| 98304 | 4 | 0.785 | 0.977 | 2.230 | 1.25x | 2.84x | 0.0001 | 0.99999 |
| 98304 | 8 | 0.825 | 0.974 | 2.341 | 1.18x | 2.84x | 0.0001 | 0.99999 |
Microbench — group_size = 32
| T_kv | bits | fork ms | bf16 ms | stock ms | vs bf16 | vs stock | max|d| | cosine |
|-------|------|---------|---------|----------|---------|----------|--------|--------|
| 1024 | 4 | 0.1689 | 0.1568 | 0.2353 | 0.93x | 1.39x | 0.0010 | 0.99999 |
| 1024 | 8 | 0.1815 | 0.1556 | 0.2371 | 0.86x | 1.31x | 0.0010 | 0.99999 |
| 2048 | 4 | 0.1742 | 0.1739 | 0.2240 | 1.00x | 1.29x | 0.0005 | 0.99999 |
| 2048 | 8 | 0.1712 | 0.1673 | 0.2239 | 0.98x | 1.31x | 0.0005 | 0.99999 |
| 4096 | 4 | 0.1831 | 0.1638 | 0.2312 | 0.89x | 1.26x | 0.0005 | 0.99999 |
| 4096 | 8 | 0.1686 | 0.1543 | 0.2345 | 0.92x | 1.39x | 0.0005 | 0.99999 |
| 16384 | 4 | 0.2326 | 0.2758 | 0.4507 | 1.19x | 1.94x | 0.0002 | 0.99999 |
| 16384 | 8 | 0.2685 | 0.3032 | 0.4954 | 1.13x | 1.85x | 0.0002 | 0.99999 |
| 32768 | 4 | 0.3656 | 0.4357 | 0.7718 | 1.19x | 2.11x | 0.0001 | 0.99999 |
| 32768 | 8 | 0.3858 | 0.4512 | 0.8546 | 1.17x | 2.22x | 0.0002 | 0.99999 |
| 65536 | 4 | 0.5848 | 0.7594 | 1.3331 | 1.30x | 2.28x | 0.0001 | 0.99999 |
| 65536 | 8 | 0.6178 | 0.7505 | 1.5507 | 1.22x | 2.51x | 0.0001 | 0.99999 |
| 98304 | 4 | 0.8049 | 0.9917 | 1.8828 | 1.23x | 2.34x | 0.0001 | 0.99999 |
| 98304 | 8 | 0.8312 | 0.9955 | 2.3631 | 1.20x | 2.84x | 0.0001 | 0.99999 |
Microbench summary
T_kv >= 16384,fused_qsdpabeats bf16 SDPA ondequantized K/V by 1.10x-1.30x and the stock dequant+SDPA path by
1.54x-2.84x across both
group_size in {32, 64}andbits in {4, 8}.T_kv in {1024, 2048, 4096}the kernel is within +/- a fewpercent of bf16 SDPA (range 0.84x-1.01x). At short context the
per-launch overhead dominates kernel time so the two-pass primitive
has no margin to exploit; the stock path still loses 1.1x-1.4x because
it pays separate dequant launches.
mx.fast.scaled_dot_product_attentionon dequantized K/V across every cell; max-abs error <= 1e-3 at the
shortest contexts (where each batch element is a single 32-element
group) and falls below 5e-4 by
T_kv = 2048and below 5e-5 byT_kv = 65536.End-to-end — heterogeneous N=3 multi-agent batching
Qwen3.6-35B-A3B-4bit, three concurrent decoding requests at matched
context length, three distinct prompts (Renaissance / quantum / climate
corpora; minimal vocab overlap). Stack: full mlx_fast patches +
KV-quant 4-bit gs=64. WS#3a OFF disables the
left_padding-awarehetero quant path so heterogeneous BatchKVCaches fall back to bf16
attention. WS#3a ON routes them through
fused_qsdpawith theper-batch
left_paddingint32 input added in this PR.Reading the table: "OFF gap" = how much heterogeneous batching loses
relative to homogeneous batching when WS#3a is off; "ON gap" = the same
when this PR's
left_paddingpath is active; "ON vs OFF" = the speedupof the new path vs. the OFF baseline at the same heterogeneous workload.
At 32k context the new path closes the hetero/homo gap from -18.93% to
-1.22% — a +21.9% decode throughput win on the realistic
multi-agent serving workload. At 16k the gain is +13.3%. At 4k the
gap was already small and the gain is +2.6%.
The OFF baseline's growing penalty with context length reflects the
quadratic-in-T_kv cost of bf16 attention vs the per-K-position cost of
the fused quantized kernel: at 4k both paths spend most of their time on
overhead, at 32k the dequant+SDPA path is paying the full bf16 attention
bill on a tensor that the fused kernel reads directly from quantized
storage.
Correctness
Across
(T_kv, bits, group_size)configurations:mx.fast.scaled_dot_product_attention(bf16)ondequantized K/V at every microbench cell above.
T_kv = 1024, scales down with context length(see microbench tables).
With
left_padding:Greedy decode tokens match the bf16-on-dequantized path token-for-token
over multi-thousand-token generations.
API
Returns shape
(B, n_q_heads, T_q, head_dim).Files
mlx/fast.hmlx/fast.cppis_equivalentanchormlx/fast_primitives.hFusedQSDPAprimitive declarationmlx/backend/metal/kernels/CMakeLists.txtmlx/backend/metal/kernels/fused_qsdpa.hmlx/backend/metal/kernels/fused_qsdpa.metalmlx/backend/metal/quantized.cppgather_qsdpa_2passdispatcher +eval_gpumlx/backend/no_gpu/primitives.cppNO_GPU_MULTI(FusedQSDPA)shimpython/src/fast.cpppython/tests/test_fast.pyeval_gpulives in the Metal backend only. The vtable is anchored byFusedQSDPA::is_equivalentinmlx/fast.cpp(always compiled) plus theNO_GPU_MULTI(FusedQSDPA)shim — same pattern asScaledDotProductAttention,RoPE,ConvertFP8, etc.Test plan
python/tests/test_fast.py:test_fused_qsdpa_correctness-- cosine >= 0.9999 vsmx.fast.scaled_dot_product_attention(bf16)on dequantized K/V acrossT_kv in {4096, 16384}xbits in {4, 8}xgroup_size in {32, 64}.test_fused_qsdpa_left_padding-- cosine >= 0.999985 vs bf16 +boolean-mask reference across
left_padding in {None, zeros, varied}xT_kv in {4096, 16384}.test_fused_qsdpa_extreme_left_padding--left_padding = [0, T_kv-1],near-full-mask regression for the online-softmax NaN edge case.
test_fused_qsdpa_input_validation-- invalid arguments(
bitsnot in{4, 8},group_sizenot in{32, 64}, wronghead_dim/gqa_factor,mode != "affine",T_q != 1,mismatched queries last-dim,
left_paddingwrong dtype/rank/length)raise
ValueError.All
mx.fast.metal_kernel-style tests in this file are gated onmx.metal.is_available(); the new tests follow the same pattern.Open questions for maintainers
mx.fastnext toscaled_dot_product_attention. Alternatively, this is logically amx.fast.quantized.scaled_dot_product_attentionoverload; I went withthe flat name to keep symmetry with
mx.fast.metal_kerneland toavoid implying a parallel quantized namespace before the API
stabilizes. Happy to rename / move if you prefer.
bitsparameter. The Python binding defaultsbits=4, group_size=64to match the most common KV-quant configuration onApple Silicon. Should the defaults be required (no default) instead,
to force callers to be explicit?
head_dim/gqa_factorconstraints. The kernel is templated onboth but the C++ wrapper currently rejects anything other than
head_dim == 256, gqa_factor == 8— the only shape that has beenvalidated on the production model. Extending to other shapes is
mechanical (add template instantiations + dispatch table entries);
should that happen in this PR or a follow-up?
left_paddinginterface. Plainint32[B]array. Alternativedesigns: (a) a single scalar mask offset, (b) a
(start, end)windowper batch row, (c) reusing the existing
maskargument with a specialsentinel. The chosen design matches what
BatchKVCache.left_paddingalready exposes, but I'd like maintainers'take before this hardens.
Generated with Claude Code.