Skip to content

Add opt-in SnapKV long-context KV cache compression#4

Merged
benjamin-levin merged 1 commit into
mainfrom
snapkv-long-context
May 19, 2026
Merged

Add opt-in SnapKV long-context KV cache compression#4
benjamin-levin merged 1 commit into
mainfrom
snapkv-long-context

Conversation

@benjamin-levin
Copy link
Copy Markdown
Owner

@benjamin-levin benjamin-levin commented May 18, 2026

Summary

Adds opt-in SnapKV content-aware KV-cache compression for long-context inference. After the prompt is prefilled, the last obs_window queries from each full-attention layer score every K position; the top-top_k mid-prompt positions (plus a leading attention sink and a trailing recent window) are physically kept and the rest are dropped. A fixed-shape SnapKVCache then drives decode at streaming-LLM speed while attending to the content the model itself flagged as relevant.

Reference: Li et al., "SnapKV: LLM Knows What You are Looking for Before Generation" (https://arxiv.org/abs/2404.14469).

Motivation

At 49k+ context the KV cache dominates decode bandwidth on Apple Silicon: each new token re-reads every cached K and V vector. With Qwen3-Next-35B-A3B-4bit at 95k, the full cache is large enough that decode is bandwidth-bound and tokens/sec falls sharply with context. SnapKV uses the model's own attention pattern on the trailing prompt window to pick the small subset of positions that actually matter for what comes next, then physically discards the rest before decode begins. The result is a fixed, much smaller cache that lets decode run at the speed it would have at the start of the prompt while preserving retrieval quality.

Measured impact

Qwen3-Next-35B-A3B-4bit on M4 Max 36 GB, validated config top_k=4096, n_sink=128, n_window=512, obs_window=32, pool_kernel=1. Headline numbers (full performance matrix below):

Actual ctx Full attn (tok/s) SnapKV (tok/s) Speedup
49k 76.2 100.2 1.32x
65k 68.7 98.7 1.44x
95k 58.4 98.3 1.68x
128k 50.8 98.7 1.94x

SnapKV holds decode throughput essentially flat at ~98 tok/s from 32k all the way to 128k while full attention falls off with bandwidth, so the speedup grows with context. Retrieval (needle-at-start, middle, end) is 3/3 at 95k.

Below the configurable min_ctx threshold (default 49152) SnapKV's extra prefill pass outweighs the decode savings, so the helper short-circuits and generate_step behaves exactly as it would without the kwarg. The length-gate row in the matrix verifies this is a true no-op.

Performance matrix

Validation run, 2026-05-18. Driven by mlx_lm.generate.generate_step(snapkv=...)
(the public API introduced by this PR). Driver script + raw per-cell JSON below.

Decode throughput vs context length

Target ctx Actual tok Full attn (tok/s) SnapKV (tok/s) Speedup Snap hit
4,096 3,057 108.99 109.02 1.00x yes
16,384 16,042 94.93 95.00 1.00x yes
32,768 31,994 82.63 99.85 1.21x yes
49,152 49,039 76.16 100.19 1.32x yes
65,536 65,352 68.69 98.71 1.44x yes
95,000 95,075 58.37 98.32 1.68x yes
128,000 128,076 50.81 98.66 1.94x yes
  • snap rows force-enable SnapKV (min_ctx=0) so the matrix reports SnapKV's behaviour at every context. The 4096 and 16384 cells use the PR's default min_ctx=49152 (column appears as snap_gated) so they exercise the length-gate short-circuit and ratio is by definition ~1.
  • SnapKV holds decode throughput essentially constant (97-100 tok/s) from 32k to 128k; full attention falls from 82 to 51 tok/s over the same range, so the speedup grows with context.

Retrieval correctness — needle-in-a-haystack at 95k

Needle position ctx_actual decode tok/s Hit Excerpt
near-start (0.5%) 95,075 97.89 yes The secret code phrase mentioned in the text is 'fluorescent
middle (50%) 95,075 98.32 yes The secret code phrase mentioned in the text is 'fluorescent
near-end (97%) 95,077 98.38 yes The secret code phrase mentioned in the text is 'fluorescent

3/3 hit at the validated config (top_k=4096, n_sink=128, n_window=512, pool_kernel=1).

Length-gate no-op (PR default min_ctx=49152)

Below the gate, snapkv={...} short-circuits inside generate_step without
modifying the prompt cache, so decode tok/s must match full attention.

Target ctx ctx_actual Full attn (tok/s) SnapKV gated (tok/s) Ratio
4,096 3,057 108.99 109.02 1.000x
16,384 16,042 94.93 95.00 1.001x

Peak GPU memory (GB)

Peak is measured across the full request, which for SnapKV includes the extra
capture-prefill pass before the trim. Steady-state decode memory is lower than the
number reported here once the un-selected K/V positions are discarded, so the +1 GB
"cost" in the table is a one-shot prefill overhead, not a sustained delta.

Target ctx Full attn SnapKV Delta
49,152 25.26 26.30 +1.04
65,536 26.70 27.74 +1.04
95,000 29.41 30.45 +1.04
128,000 32.30 33.33 +1.04

Methodology

  • Model: mlx-community/Qwen3.6-35B-A3B-4bit (Qwen3-Next-35B-A3B-Instruct, 4-bit).
  • Hardware: M4 Max 36 GB, MLX wired limit 32 GB.
  • SnapKV config (matches the validated config in the PR description):
    top_k=4096, n_sink=128, n_window=512, obs_window=32, pool_kernel=1.
  • Prompt: needle-in-a-haystack — a secret phrase is injected at the requested fractional position inside a long filler body, and the model is asked to repeat the phrase verbatim then emit five ML-library descriptions so decode amortises over a stable window.
  • Window: 200 decoded tokens per cell, with EOS check disabled in the driver so each TPS measurement covers >=2 s of decode.
  • Driver: mlx_lm.generate.generate_step(prompt, model, snapkv={...}) — the public API introduced by this PR. No model-internals reach-around.
  • Serialization: every cell ran under the shared GPU lockfile so no other agent perturbed the timings.
Raw per-cell results
[
  {
    "ctx_req": 4096,
    "ctx_actual": 3057,
    "mode": "full",
    "needle_pos": 0.5,
    "decode_tps": 108.99,
    "prefill_s": 2.15,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 21.45
  },
  {
    "ctx_req": 4096,
    "ctx_actual": 3057,
    "mode": "snap_gate_default",
    "needle_pos": 0.5,
    "decode_tps": 109.02,
    "prefill_s": 2.14,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 21.45
  },
  {
    "ctx_req": 22000,
    "ctx_actual": 16042,
    "mode": "full",
    "needle_pos": 0.5,
    "decode_tps": 94.93,
    "prefill_s": 12.51,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 22.45
  },
  {
    "ctx_req": 22000,
    "ctx_actual": 16042,
    "mode": "snap_gate_default",
    "needle_pos": 0.5,
    "decode_tps": 95.0,
    "prefill_s": 12.18,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 22.45
  },
  {
    "ctx_req": 44000,
    "ctx_actual": 31994,
    "mode": "full",
    "needle_pos": 0.5,
    "decode_tps": 82.63,
    "prefill_s": 28.13,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 23.82
  },
  {
    "ctx_req": 44000,
    "ctx_actual": 31994,
    "mode": "snap",
    "needle_pos": 0.5,
    "decode_tps": 99.85,
    "prefill_s": 32.44,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 24.89
  },
  {
    "ctx_req": 67500,
    "ctx_actual": 49039,
    "mode": "full",
    "needle_pos": 0.5,
    "decode_tps": 76.16,
    "prefill_s": 49.56,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 25.26
  },
  {
    "ctx_req": 67500,
    "ctx_actual": 49039,
    "mode": "snap",
    "needle_pos": 0.5,
    "decode_tps": 100.19,
    "prefill_s": 57.36,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 26.3
  },
  {
    "ctx_req": 90000,
    "ctx_actual": 65352,
    "mode": "full",
    "needle_pos": 0.5,
    "decode_tps": 68.69,
    "prefill_s": 74.8,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 26.7
  },
  {
    "ctx_req": 90000,
    "ctx_actual": 65352,
    "mode": "snap",
    "needle_pos": 0.5,
    "decode_tps": 98.71,
    "prefill_s": 87.01,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 27.74
  },
  {
    "ctx_req": 131000,
    "ctx_actual": 95075,
    "mode": "full",
    "needle_pos": 0.5,
    "decode_tps": 58.37,
    "prefill_s": 134.23,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 29.41
  },
  {
    "ctx_req": 131000,
    "ctx_actual": 95075,
    "mode": "snap",
    "needle_pos": 0.005,
    "decode_tps": 97.89,
    "prefill_s": 155.42,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 30.45
  },
  {
    "ctx_req": 131000,
    "ctx_actual": 95075,
    "mode": "snap",
    "needle_pos": 0.5,
    "decode_tps": 98.32,
    "prefill_s": 157.1,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 30.45
  },
  {
    "ctx_req": 131000,
    "ctx_actual": 95077,
    "mode": "snap",
    "needle_pos": 0.97,
    "decode_tps": 98.38,
    "prefill_s": 155.55,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 30.45
  },
  {
    "ctx_req": 176500,
    "ctx_actual": 128076,
    "mode": "full",
    "needle_pos": 0.5,
    "decode_tps": 50.81,
    "prefill_s": 218.84,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 32.3
  },
  {
    "ctx_req": 176500,
    "ctx_actual": 128076,
    "mode": "snap",
    "needle_pos": 0.5,
    "decode_tps": 98.66,
    "prefill_s": 252.88,
    "n_gen": 200,
    "hit": true,
    "peak_gb": 33.33
  }
]

Implementation

Three files, +562 lines, no existing behavior changed:

  • mlx_lm/models/cache.py - adds SnapKVCache: a fixed-shape decode-time cache that decouples the logical RoPE offset (used to RoPE-encode the next query) from the physical buffer length (used for storage). The pinned region (sink + selected middle positions) is never evicted; the trailing recent window slides as decode emits tokens, so the buffer shape is constant across decode steps and MLX's scaled_dot_product_attention kernel plan stays stable.
  • mlx_lm/snapkv.py - new module:
    • _snapkv_attention_call - drop-in replacement for Qwen3NextAttention.__call__ that captures the last obs_window post-RoPE queries when capture is enabled. Bit-identical to the original when capture is off.
    • patch_for_snapkv / unpatch_snapkv - install / restore the patch at the class level (not instance: Python dispatches obj(x) via type(obj).__call__, so an instance-level monkey-patch silently never fires).
    • _snapkv_select_indices - GQA-reduce queries, score K, optional 1D box-filter pooling, top-K argpartition over the mid region, return sorted-ascending indices.
    • snapkv_prefill_and_trim - chunked prefill driver that swaps the trimmed SnapKVCache into the prompt cache list in place.
  • mlx_lm/generate.py - adds snapkv: Optional[Dict[str, Any]] = None kwarg to generate_step. When provided, a small helper checks the length gate (min_ctx), confirms the prompt cache is fresh, and confirms the Qwen3-Next attention class is importable; on success it runs SnapKV prefill+trim on prompt[:-1] and leaves the final token for the existing decode path so logits processors / sampler / KV-quantize hooks fire normally on the first emitted token. On any miss it returns the original prompt unchanged and the function falls through to the standard prefill loop.

Recognized snapkv keys: top_k (default 4096), n_sink (128), n_window (512), obs_window (32), pool_kernel (1), min_ctx (49152), prefill_chunk (prefill_step_size).

Scope

  • Wires Qwen3NextAttention only (the model where SnapKV is bench-validated upstream of this PR).
  • The selection logic and SnapKVCache itself are model-agnostic; extending coverage to another attention class only requires routing patch_for_snapkv through that class and confirming it has a compatible post-RoPE query path.

Backwards compatibility

  • snapkv=None is the default and means SnapKV is never installed; generate_step behaves exactly as before. (Verified by test_generate_step_snapkv_none_is_default.)
  • snapkv={...} on a prompt shorter than min_ctx short-circuits before any class is touched. (Verified by test_generate_step_snapkv_min_ctx_gate_skips_short_prompt.)
  • snapkv={...} on a pre-populated prompt cache (mid-conversation continuation) detects the populated cache and falls through. (Verified by test_generate_step_snapkv_skips_when_cache_already_populated.)
  • No other public API changed; no existing tests modified.

Tests

tests/test_models.py - new TestSnapKVCache class:

  • test_update_and_fetch_sliding_window - pinned region stays put across single-token and multi-token updates; FIFO recent window evicts oldest entries; logical offset advances by L while physical shape stays constant.
  • test_n_pin_validation - constructor rejects n_pin > n_keep.
  • test_snapkv_cache_rejected_by_trim_helpers - can_trim_prompt_cache returns False when the list contains a SnapKVCache (because it is intentionally is_trimmable() == False).
  • test_select_indices_orders_sink_topk_window - boosted mid positions win argpartition; output is sink + selected_mid + window sorted ascending; dtype is int32; total length is n_sink + top_k + n_window.
  • test_select_indices_returns_all_when_mid_empty - degenerate case where n_sink + n_window >= T returns arange(T) unchanged.
  • test_select_indices_top_k_zero - top_k=0 returns only sink + window.

tests/test_generate.py - new generate_step tests:

  • test_generate_step_snapkv_none_is_default - explicit snapkv=None matches the no-kwarg path token-for-token.
  • test_generate_step_snapkv_min_ctx_gate_skips_short_prompt - large min_ctx short-circuits to identical output as snapkv=None.
  • test_generate_step_snapkv_skips_when_cache_already_populated - populated prompt_cache triggers the early return; output matches the standard mid-conversation continuation path.

Limitations

  • SnapKVCache is intentionally is_trimmable() == False - it is a lossy summary of the prompt, not a replayable history. trim_prompt_cache will refuse a list containing it.
  • SnapKVCache does not expose to_quantized. Callers that want KV-quant should not enable SnapKV on the same request; they are alternative bandwidth strategies.
  • SnapKVCache cannot be saved or loaded via save_prompt_cache / load_prompt_cache - the kept K vectors retain their original RoPE positions, so the cache is only valid against the original prompt and the original model state.
  • Only Qwen3-Next attention is currently wired; the helper safely skips on other model families.

Fork validation PR. Not for upstream submission until I run wider model coverage and the upstream mac_build_and_test flake on tokenizer tests is sorted out.

Adds SnapKV-style physical KV-cache trimming for long-context prefill +
decode. Opt-in via a `snapkv=dict(...)` kwarg on `generate_step`; default
off, and no-op for prompts shorter than `min_ctx` (49152 tokens by default).

How it works
------------
After the prompt is prefilled, the last `obs_window` queries from each
full-attention layer are used to score every K position; the top-`top_k`
mid-prompt positions (plus a leading attention sink and a trailing recent
window) are kept and the rest are physically dropped. A new fixed-shape
`SnapKVCache` replaces the layer's `KVCache`. Decode then runs at
streaming-LLM speed while attending to the content the model itself
flagged as relevant.

Reference: Li et al., "SnapKV: LLM Knows What You are Looking for Before
Generation" (https://arxiv.org/abs/2404.14469).

Files
-----
- `mlx_lm/models/cache.py`: adds `SnapKVCache`, a fixed-shape decode-time
  cache that decouples logical RoPE offset from physical buffer length.
  Intentionally not trimmable, not quantizable, not persistable.
- `mlx_lm/snapkv.py`: new module - class-level monkey-patch on
  `Qwen3NextAttention.__call__` for query capture, position scoring
  (`_snapkv_select_indices`), and the prefill+trim driver
  (`snapkv_prefill_and_trim`).
- `mlx_lm/generate.py`: adds `snapkv: Optional[dict] = None` kwarg to
  `generate_step` with a small helper that fires SnapKV only when
  appropriate (long enough prompt, fresh prompt cache, model exposes
  Qwen3-Next attention) and falls back to the standard prefill path
  otherwise.
- `tests/test_models.py`: `TestSnapKVCache` covers fixed-shape sliding
  window semantics, `n_pin` validation, `is_trimmable()` rejection by
  `can_trim_prompt_cache`, and `_snapkv_select_indices` ordering /
  edge cases (empty mid region, `top_k=0`).
- `tests/test_generate.py`: covers `snapkv=None` matching the default
  path, the `min_ctx` length gate skipping short prompts, and the
  populated-cache gate falling back to the standard prefill loop.

Measured impact
---------------
Qwen3-Next-35B-A3B-4bit, M4 Max 36GB, 95k context: 1.31x end-to-end
speedup with 3/3 retrieval pass at top_k=4096, n_sink=128, n_window=512.
Below ~48k context SnapKV's extra prefill pass outweighs the decode
savings; the `min_ctx` gate prevents accidental regression on short
prompts.

Scope
-----
Ships the Qwen3-Next attention hook only (the model where SnapKV is
bench-validated). The selection logic and cache class are model-agnostic;
generalizing only requires routing `patch_for_snapkv` through additional
attention classes that share the post-RoPE query path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@benjamin-levin benjamin-levin marked this pull request as ready for review May 19, 2026 18:24
@benjamin-levin benjamin-levin merged commit 88437a9 into main May 19, 2026
1 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant