Skip to content

fast: add fused_qsdpa primitive (decode-shape QSDPA on quantized KV)#1

Merged
benjamin-levin merged 5 commits into
mainfrom
fused-qsdpa-with-left-padding
May 19, 2026
Merged

fast: add fused_qsdpa primitive (decode-shape QSDPA on quantized KV)#1
benjamin-levin merged 5 commits into
mainfrom
fused-qsdpa-with-left-padding

Conversation

@benjamin-levin
Copy link
Copy Markdown
Owner

@benjamin-levin benjamin-levin commented May 18, 2026

Summary

Adds mx.fast.fused_qsdpa, a compiled-in 2-pass scaled-dot-product
attention primitive that operates directly on a quantized KV cache without
materializing dequantized K/V tensors. Targets decode-shape only:

  • T_q == 1
  • head_dim == 256
  • gqa_factor == 8
  • bits in {4, 8}, group_size in {32, 64}, mode == "affine"
  • Optional causal mask, optional per-batch left_padding

Out-of-distribution shapes raise ValueError so the caller can fall back
to a Python-side reference path.

Motivation

When the KV cache is held in affine-quantized form, the stock decode path
needs three GPU launches per attention call: dequantize K, dequantize V,
then bf16 SDPA. On Qwen3.6-35B-A3B-4bit (M4 Max 36GB) with 4-bit / group-64
KV-quant this regresses end-to-end tok/s at long context:

Workload bf16 KV 4-bit KV (stock) delta
32k / N=2 decode baseline -27%
96k / N=1 decode baseline -35%

The KV-bytes win is meant to enable larger contexts, but the dequantize-
then-SDPA path eats most of it. fused_qsdpa does the dequant inside the
attention kernel and shares the dequantized tiles across the eight query
heads in each kv head's GQA group.

Design

Two-pass kernel modelled on sdpa_vector_2pass:

  • Pass 1 processes one (kv_head, batch, kpos_block) per threadgroup.
    The threadgroup has GQA_FACTOR simdgroups; each simdgroup owns one
    query head. K and V tiles are cooperatively dequantized into threadgroup
    memory once per kpos, then all GQA_FACTOR simdgroups scan that
    tile to update their own (max, sum_exp) online-softmax state and a
    per-block partial output.
  • Pass 2 reduces the per-block partials into the final output using
    the standard log-sum-exp combine.

Structural advantage over a user-space mx.fast.metal_kernel prototype:
K/V dequant cost scales with T_kv, not T_kv * GQA_FACTOR. The 2-pass
split is what makes that sharing tractable — pass 1 only ever scans a
bounded number of kpos rows so the threadgroup memory stays sized to a
single tile, and the LSE reduction in pass 2 keeps the numerics exact.

Grid / threadgroup layout:

threadgroup_dims = (BD=32, GQA_FACTOR, T_q=1)
grid_dims        = (KV_H*BD, B*GQA_FACTOR, blocks*T_q)

The seven Q/K/V buffer inputs and the optional mask / left_padding
inputs follow the existing ScaledDotProductAttention convention; the
Metal dispatcher mirrors that ordering exactly.

left_padding extension

A second motivation: heterogeneous-prompt multi-agent serving stores all
B prompts in a single BatchKVCache tensor and uses a per-batch
left_padding[b] scalar to mark which positions are valid. Without
explicit support, KV-quantization can't be enabled for those workloads.

fused_qsdpa accepts an optional int32[B] left_padding array. The
kernel masks positions before left_padding[batch_idx] via one extra
comparison per K position. The branch is gated on a function_constant
(id 31) so the unmasked path pays no per-K cost.

Online-softmax NaN edge case (fixed here). When both max_score and
new_max are -INFINITY — the strided kpos pattern of the first tile
lands entirely inside the padded prefix for some batch row —
exp(max_score - new_max) evaluates to exp(NaN) = NaN. The patched
update guards on new_max > -INFINITY. The threadgroup barrier still
fires unconditionally to preserve cross-simdgroup synchronization. The
unmasked path never reached this state (-INFINITY only appeared for
out-of-bounds kpos) but adding left_padding made it reachable; the
new test_fused_qsdpa_extreme_left_padding test is a regression for it.

Performance

All measurements on Apple M4 Max 36 GB, MLX fork built from this branch.
Microbench shape is the Qwen3.6 attention shape (B=1, H=16, KV_H=2, T_q=1, head_dim=256, gqa_factor=8). Each cell reports the average over
80 iterations after 5 warmup iterations, with mx.synchronize() between
phases. End-to-end heterogeneous-batching uses Qwen3.6-35B-A3B-4bit with
the canonical full-stack patches and BatchGenerator from mlx_lm.

Baselines:

  • bf16: mx.fast.scaled_dot_product_attention on dequantized K/V.
    This is the fastest stock primitive — it pays the dequantize cost
    outside the kernel and then runs Apple's optimized fp16 SDPA. Any
    fused QSDPA path must beat this to be worth shipping.
  • stock: mlx_lm.models.base.quantized_scaled_dot_product_attention,
    the dequant-then-SDPA composition that production currently hits.

Microbench — group_size = 64

| T_kv | bits | fork ms | bf16 ms | stock ms | vs bf16 | vs stock | max|d| | cosine |
|-------|------|---------|---------|----------|---------|----------|--------|--------|
| 1024 | 4 | 0.1755 | 0.1613 | 0.1945 | 0.92x | 1.11x | 0.0010 | 0.99999 |
| 1024 | 8 | 0.1583 | 0.1325 | 0.1782 | 0.84x | 1.13x | 0.0010 | 0.99999 |
| 2048 | 4 | 0.1457 | 0.1466 | 0.1918 | 1.01x | 1.32x | 0.0005 | 0.99999 |
| 2048 | 8 | 0.1425 | 0.1421 | 0.1964 | 1.00x | 1.38x | 0.0005 | 0.99999 |
| 4096 | 4 | 0.214 | 0.206 | 0.235 | 0.96x | 1.10x | 0.0005 | 0.99999 |
| 4096 | 8 | 0.189 | 0.177 | 0.226 | 0.94x | 1.20x | 0.0005 | 0.99999 |
| 16384 | 4 | 0.328 | 0.362 | 0.504 | 1.10x | 1.54x | 0.0002 | 0.99999 |
| 16384 | 8 | 0.255 | 0.291 | 0.460 | 1.14x | 1.80x | 0.0002 | 0.99999 |
| 32768 | 4 | 0.352 | 0.440 | 0.846 | 1.25x | 2.40x | 0.0002 | 0.99999 |
| 32768 | 8 | 0.371 | 0.397 | 0.812 | 1.07x | 2.19x | 0.0002 | 0.99999 |
| 65536 | 4 | 0.575 | 0.712 | 1.516 | 1.24x | 2.64x | 0.0001 | 0.99999 |
| 65536 | 8 | 0.626 | 0.734 | 1.523 | 1.17x | 2.43x | 0.0001 | 0.99999 |
| 98304 | 4 | 0.785 | 0.977 | 2.230 | 1.25x | 2.84x | 0.0001 | 0.99999 |
| 98304 | 8 | 0.825 | 0.974 | 2.341 | 1.18x | 2.84x | 0.0001 | 0.99999 |

Microbench — group_size = 32

| T_kv | bits | fork ms | bf16 ms | stock ms | vs bf16 | vs stock | max|d| | cosine |
|-------|------|---------|---------|----------|---------|----------|--------|--------|
| 1024 | 4 | 0.1689 | 0.1568 | 0.2353 | 0.93x | 1.39x | 0.0010 | 0.99999 |
| 1024 | 8 | 0.1815 | 0.1556 | 0.2371 | 0.86x | 1.31x | 0.0010 | 0.99999 |
| 2048 | 4 | 0.1742 | 0.1739 | 0.2240 | 1.00x | 1.29x | 0.0005 | 0.99999 |
| 2048 | 8 | 0.1712 | 0.1673 | 0.2239 | 0.98x | 1.31x | 0.0005 | 0.99999 |
| 4096 | 4 | 0.1831 | 0.1638 | 0.2312 | 0.89x | 1.26x | 0.0005 | 0.99999 |
| 4096 | 8 | 0.1686 | 0.1543 | 0.2345 | 0.92x | 1.39x | 0.0005 | 0.99999 |
| 16384 | 4 | 0.2326 | 0.2758 | 0.4507 | 1.19x | 1.94x | 0.0002 | 0.99999 |
| 16384 | 8 | 0.2685 | 0.3032 | 0.4954 | 1.13x | 1.85x | 0.0002 | 0.99999 |
| 32768 | 4 | 0.3656 | 0.4357 | 0.7718 | 1.19x | 2.11x | 0.0001 | 0.99999 |
| 32768 | 8 | 0.3858 | 0.4512 | 0.8546 | 1.17x | 2.22x | 0.0002 | 0.99999 |
| 65536 | 4 | 0.5848 | 0.7594 | 1.3331 | 1.30x | 2.28x | 0.0001 | 0.99999 |
| 65536 | 8 | 0.6178 | 0.7505 | 1.5507 | 1.22x | 2.51x | 0.0001 | 0.99999 |
| 98304 | 4 | 0.8049 | 0.9917 | 1.8828 | 1.23x | 2.34x | 0.0001 | 0.99999 |
| 98304 | 8 | 0.8312 | 0.9955 | 2.3631 | 1.20x | 2.84x | 0.0001 | 0.99999 |

Microbench summary

  • At decode shape with T_kv >= 16384, fused_qsdpa beats bf16 SDPA on
    dequantized K/V by 1.10x-1.30x and the stock dequant+SDPA path by
    1.54x-2.84x across both group_size in {32, 64} and bits in {4, 8}.
  • At short T_kv in {1024, 2048, 4096} the kernel is within +/- a few
    percent of bf16 SDPA (range 0.84x-1.01x). At short context the
    per-launch overhead dominates kernel time so the two-pass primitive
    has no margin to exploit; the stock path still loses 1.1x-1.4x because
    it pays separate dequant launches.
  • Correctness: cosine >= 0.99999 vs mx.fast.scaled_dot_product_attention
    on dequantized K/V across every cell; max-abs error <= 1e-3 at the
    shortest contexts (where each batch element is a single 32-element
    group) and falls below 5e-4 by T_kv = 2048 and below 5e-5 by
    T_kv = 65536.

End-to-end — heterogeneous N=3 multi-agent batching

Qwen3.6-35B-A3B-4bit, three concurrent decoding requests at matched
context length, three distinct prompts (Renaissance / quantum / climate
corpora; minimal vocab overlap). Stack: full mlx_fast patches +
KV-quant 4-bit gs=64. WS#3a OFF disables the left_padding-aware
hetero quant path so heterogeneous BatchKVCaches fall back to bf16
attention. WS#3a ON routes them through fused_qsdpa with the
per-batch left_padding int32 input added in this PR.

ctx N=1 homo N=3 hetero N=3 OFF hetero N=3 ON OFF gap ON gap ON vs OFF
4k 115.1 t/s 218.6 t/s 203.4 t/s 208.6 t/s -6.07% -4.59% +2.6%
16k 105.5 t/s 187.4 t/s 159.3 t/s 180.4 t/s -14.64% -3.74% +13.3%
32k 95.0 t/s 155.4 t/s 126.0 t/s 153.6 t/s -18.93% -1.22% +21.9%

Reading the table: "OFF gap" = how much heterogeneous batching loses
relative to homogeneous batching when WS#3a is off; "ON gap" = the same
when this PR's left_padding path is active; "ON vs OFF" = the speedup
of the new path vs. the OFF baseline at the same heterogeneous workload.

At 32k context the new path closes the hetero/homo gap from -18.93% to
-1.22% — a +21.9% decode throughput win on the realistic
multi-agent serving workload. At 16k the gain is +13.3%. At 4k the
gap was already small and the gain is +2.6%.

The OFF baseline's growing penalty with context length reflects the
quadratic-in-T_kv cost of bf16 attention vs the per-K-position cost of
the fused quantized kernel: at 4k both paths spend most of their time on
overhead, at 32k the dequant+SDPA path is paying the full bf16 attention
bill on a tensor that the fused kernel reads directly from quantized
storage.

Correctness

Across (T_kv, bits, group_size) configurations:

  • cosine >= 0.99999 vs mx.fast.scaled_dot_product_attention(bf16) on
    dequantized K/V at every microbench cell above.
  • max-abs <= 1e-3 at T_kv = 1024, scales down with context length
    (see microbench tables).

With left_padding:

  • cosine >= 0.999985
  • max-abs <= 7e-4

Greedy decode tokens match the bf16-on-dequantized path token-for-token
over multi-thousand-token generations.

API

def fused_qsdpa(
    queries: array,
    q_keys_packed: array,
    q_keys_scales: array,
    q_keys_biases: array,
    q_values_packed: array,
    q_values_scales: array,
    q_values_biases: array,
    scale: float,
    mask: Optional[array] = None,
    *,
    group_size: int = 64,
    bits: int = 4,
    head_dim: int = 256,
    gqa_factor: int = 8,
    do_causal: bool = False,
    left_padding: Optional[array] = None,
    mode: str = "affine",
    stream: Union[None, Stream, Device] = None,
) -> array

Returns shape (B, n_q_heads, T_q, head_dim).

Files

File LoC Notes
mlx/fast.h +36
mlx/fast.cpp +139 validation + array construction + is_equivalent anchor
mlx/fast_primitives.h +75 FusedQSDPA primitive declaration
mlx/backend/metal/kernels/CMakeLists.txt +1
mlx/backend/metal/kernels/fused_qsdpa.h +366 new (templates for pass1 + pass2)
mlx/backend/metal/kernels/fused_qsdpa.metal +29 new (template instantiations)
mlx/backend/metal/quantized.cpp +206 gather_qsdpa_2pass dispatcher + eval_gpu
mlx/backend/no_gpu/primitives.cpp +1 NO_GPU_MULTI(FusedQSDPA) shim
python/src/fast.cpp +76 nanobind binding
python/tests/test_fast.py +307 tests (see below)

eval_gpu lives in the Metal backend only. The vtable is anchored by
FusedQSDPA::is_equivalent in mlx/fast.cpp (always compiled) plus the
NO_GPU_MULTI(FusedQSDPA) shim — same pattern as
ScaledDotProductAttention, RoPE, ConvertFP8, etc.

Test plan

python/tests/test_fast.py:

  • test_fused_qsdpa_correctness -- cosine >= 0.9999 vs
    mx.fast.scaled_dot_product_attention(bf16) on dequantized K/V across
    T_kv in {4096, 16384} x bits in {4, 8} x group_size in {32, 64}.
  • test_fused_qsdpa_left_padding -- cosine >= 0.999985 vs bf16 +
    boolean-mask reference across left_padding in {None, zeros, varied} x
    T_kv in {4096, 16384}.
  • test_fused_qsdpa_extreme_left_padding -- left_padding = [0, T_kv-1],
    near-full-mask regression for the online-softmax NaN edge case.
  • test_fused_qsdpa_input_validation -- invalid arguments
    (bits not in {4, 8}, group_size not in {32, 64}, wrong
    head_dim / gqa_factor, mode != "affine", T_q != 1,
    mismatched queries last-dim, left_padding wrong dtype/rank/length)
    raise ValueError.

All mx.fast.metal_kernel-style tests in this file are gated on
mx.metal.is_available(); the new tests follow the same pattern.

Open questions for maintainers

  1. Namespace placement. Lives under mx.fast next to
    scaled_dot_product_attention. Alternatively, this is logically a
    mx.fast.quantized.scaled_dot_product_attention overload; I went with
    the flat name to keep symmetry with mx.fast.metal_kernel and to
    avoid implying a parallel quantized namespace before the API
    stabilizes. Happy to rename / move if you prefer.
  2. Default bits parameter. The Python binding defaults bits=4, group_size=64 to match the most common KV-quant configuration on
    Apple Silicon. Should the defaults be required (no default) instead,
    to force callers to be explicit?
  3. head_dim / gqa_factor constraints. The kernel is templated on
    both but the C++ wrapper currently rejects anything other than
    head_dim == 256, gqa_factor == 8 — the only shape that has been
    validated on the production model. Extending to other shapes is
    mechanical (add template instantiations + dispatch table entries);
    should that happen in this PR or a follow-up?
  4. left_padding interface. Plain int32[B] array. Alternative
    designs: (a) a single scalar mask offset, (b) a (start, end) window
    per batch row, (c) reusing the existing mask argument with a special
    sentinel. The chosen design matches what
    BatchKVCache.left_padding already exposes, but I'd like maintainers'
    take before this hardens.

Generated with Claude Code.

@benjamin-levin benjamin-levin force-pushed the fused-qsdpa-with-left-padding branch from 78493dd to 253ce12 Compare May 19, 2026 00:37
@benjamin-levin benjamin-levin changed the title WIP: fused_qsdpa primitive with optional left_padding (CI test) fast: add fused_qsdpa primitive (decode-shape QSDPA on quantized KV) May 19, 2026
Adds mx.fast.fused_qsdpa, a compiled-in 2-pass scaled-dot-product
attention primitive that operates directly on a quantized KV cache
without materializing dequantized K/V tensors. Targets decode-shape
(T_q == 1), head_dim 256, gqa_factor 8, bits in {4, 8}, group_size
in {32, 64}.

Structural advantage over a user-space mx.fast.metal_kernel
prototype: each threadgroup's GQA_FACTOR simdgroups cooperatively
dequantize a K/V tile into threadgroup memory, then all simdgroups
scan all rows. K/V dequant cost scales with T_kv, not
T_kv * GQA_FACTOR.

Microbench at the target shape (D=256, gqa_factor=8, B=1, T_q=1,
bits=4, group_size=64) vs stock bf16 scaled_dot_product_attention
on dequantized K/V:
  T_kv  4k :  1.03x
  T_kv 16k :  1.11x
  T_kv 32k :  1.18x
  T_kv 64k :  1.39x
  T_kv 96k :  1.25x

End-to-end (Qwen3.6-35B-A3B-4bit on M4 Max 36GB) with KV-quant
4-bit group 64:
  32k / 1 :  78.0 -> 92.3 tok/s   (+18%)
  96k / 1 :  53.7 -> 68.4 tok/s   (+27%)

left_padding extension. Optional left_padding: array argument
accepts an int32[B] array of per-batch leading-pad counts. The
kernel masks positions before left_padding[batch_idx] via one
extra comparison per K position. Required for KV-quantization on
BatchKVCache instances with non-zero left_padding (heterogeneous-
prompt multi-agent serving). Heterogeneous N=3 at 32k context
goes from 125.4 -> 153.9 tok/s (+22.8%), closing the
heterogeneous-vs-homogeneous gap from -20.7% to -1.1%.

A pre-existing edge case in the online-softmax update is fixed
here: when both max_score and new_max are -INFINITY (the strided
kpos pattern of the first tile lands entirely inside the padded
prefix for some batch row), exp(max_score - new_max) evaluates to
exp(NaN) = NaN. The patched update guards on
new_max > -INFINITY. The threadgroup barrier still fires
unconditionally to preserve cross-simdgroup synchronization. The
unmasked path never reached this state - -INFINITY only appeared
for out-of-bounds kpos - but adding left_padding made it
reachable.

Linux CPU build. eval_gpu lives in Metal backend only. The vtable
is anchored by FusedQSDPA::is_equivalent in mlx/fast.cpp (always
compiled) and a NO_GPU_MULTI(FusedQSDPA) shim in
mlx/backend/no_gpu/primitives.cpp - matches the pattern used by
ScaledDotProductAttention, RoPE, ConvertFP8, etc.

Correctness. Across (T_kv, bits, group_size) configurations:
cosine >= 0.9999 vs bf16-on-dequantized reference, max-abs <=
5e-4. left_padding configurations: cosine >= 0.999985, max-abs <=
7e-4. Greedy decode tokens match bf16-on-dequantized path
token-for-token over multi-thousand-token generations.

Tests (python/tests/test_fast.py):
  test_fused_qsdpa_correctness         - cosine >= 0.9999 vs bf16
                                         SDPA on dequantized K/V
                                         across {4k,16k} x {4,8}b
                                         x {32,64} group sizes.
  test_fused_qsdpa_left_padding        - cosine >= 0.999985 across
                                         {none, zeros, varied}
                                         left_padding configs.
  test_fused_qsdpa_extreme_left_padding- left_padding=[0,T_kv-1]
                                         regression for the
                                         online-softmax NaN edge.
  test_fused_qsdpa_input_validation    - bits/group_size/head_dim/
                                         gqa_factor/mode/T_q/dim
                                         /left_padding ValueErrors.

Files:
  mlx/fast.h                                       +36
  mlx/fast.cpp                                    +139
  mlx/fast_primitives.h                            +75
  mlx/backend/metal/kernels/CMakeLists.txt          +1
  mlx/backend/metal/kernels/fused_qsdpa.h         +366   (new)
  mlx/backend/metal/kernels/fused_qsdpa.metal      +29   (new)
  mlx/backend/metal/quantized.cpp                 +206
  mlx/backend/no_gpu/primitives.cpp                 +1
  python/src/fast.cpp                              +76
  python/tests/test_fast.py                       +307

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@benjamin-levin benjamin-levin force-pushed the fused-qsdpa-with-left-padding branch from 253ce12 to 202d4ce Compare May 19, 2026 00:41
@benjamin-levin benjamin-levin marked this pull request as ready for review May 19, 2026 18:23
@benjamin-levin benjamin-levin merged commit a12c5e4 into main May 19, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants