Add opt-in SnapKV long-context KV cache compression #4
Merged
Adds SnapKV-style physical KV-cache trimming for long-context prefill + decode. Opt-in via a `snapkv=dict(...)` kwarg on `generate_step`; default off, and no-op for prompts shorter than `min_ctx` (49152 tokens by default).

How it works
------------
After the prompt is prefilled, the last `obs_window` queries from each full-attention layer are used to score every K position; the top-`top_k` mid-prompt positions (plus a leading attention sink and a trailing recent window) are kept and the rest are physically dropped. A new fixed-shape `SnapKVCache` replaces the layer's `KVCache`. Decode then runs at streaming-LLM speed while attending to the content the model itself flagged as relevant.

Reference: Li et al., "SnapKV: LLM Knows What You are Looking for Before Generation" (https://arxiv.org/abs/2404.14469).

Files
-----
- `mlx_lm/models/cache.py`: adds `SnapKVCache`, a fixed-shape decode-time cache that decouples logical RoPE offset from physical buffer length. Intentionally not trimmable, not quantizable, not persistable.
- `mlx_lm/snapkv.py`: new module - class-level monkey-patch on `Qwen3NextAttention.__call__` for query capture, position scoring (`_snapkv_select_indices`), and the prefill+trim driver (`snapkv_prefill_and_trim`).
- `mlx_lm/generate.py`: adds `snapkv: Optional[dict] = None` kwarg to `generate_step` with a small helper that fires SnapKV only when appropriate (long enough prompt, fresh prompt cache, model exposes Qwen3-Next attention) and falls back to the standard prefill path otherwise.
- `tests/test_models.py`: `TestSnapKVCache` covers fixed-shape sliding window semantics, `n_pin` validation, `is_trimmable()` rejection by `can_trim_prompt_cache`, and `_snapkv_select_indices` ordering / edge cases (empty mid region, `top_k=0`).
- `tests/test_generate.py`: covers `snapkv=None` matching the default path, the `min_ctx` length gate skipping short prompts, and the populated-cache gate falling back to the standard prefill loop.

Measured impact
---------------
Qwen3-Next-35B-A3B-4bit, M4 Max 36GB, 95k context: 1.31x end-to-end speedup with 3/3 retrieval pass at top_k=4096, n_sink=128, n_window=512. Below ~48k context SnapKV's extra prefill pass outweighs the decode savings; the `min_ctx` gate prevents accidental regression on short prompts.

Scope
-----
Ships the Qwen3-Next attention hook only (the model where SnapKV is bench-validated). The selection logic and cache class are model-agnostic; generalizing only requires routing `patch_for_snapkv` through additional attention classes that share the post-RoPE query path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2fad280 to 0c9a1de
Summary
Adds opt-in SnapKV content-aware KV-cache compression for long-context inference. After the prompt is prefilled, the last `obs_window` queries from each full-attention layer score every K position; the top-`top_k` mid-prompt positions (plus a leading attention sink and a trailing recent window) are physically kept and the rest are dropped. A fixed-shape `SnapKVCache` then drives decode at streaming-LLM speed while attending to the content the model itself flagged as relevant.

Reference: Li et al., "SnapKV: LLM Knows What You are Looking for Before Generation" (https://arxiv.org/abs/2404.14469).
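
The selection step described above can be illustrated with a small pure-Python sketch. This is not the PR's `_snapkv_select_indices`: the function name `select_keep_indices` is hypothetical, it handles a single attention head with no GQA reduction or pooling, it ignores causal masking inside the observation window, and it uses a full sort where the PR uses `argpartition`.

```python
import math


def select_keep_indices(queries, keys, n_sink, n_window, top_k):
    """Return sorted prompt positions to keep (hypothetical helper, single head).

    queries: list of obs_window query vectors (the trailing prompt queries)
    keys:    list of T key vectors, one per prompt position
    """
    T = len(keys)
    # Degenerate case: sink and window already cover the whole prompt.
    if n_sink + n_window >= T:
        return list(range(T))

    d = len(keys[0])
    scores = [0.0] * T
    for q in queries:
        # Softmax attention of this observation query over every key.
        logits = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        peak = max(logits)
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        for t in range(T):
            scores[t] += exps[t] / total

    # Only mid-prompt positions compete for the top_k slots.
    mid = range(n_sink, T - n_window)
    top = sorted(mid, key=lambda t: scores[t], reverse=True)[:top_k]

    sink = list(range(n_sink))
    window = list(range(T - n_window, T))
    return sorted(sink + top + window)


# Toy prompt of 20 positions where keys 7 and 12 align with the queries.
keys = [[0.0] * 4 for _ in range(20)]
keys[7] = [5.0] * 4
keys[12] = [5.0] * 4
queries = [[1.0] * 4 for _ in range(3)]
kept = select_keep_indices(queries, keys, n_sink=2, n_window=3, top_k=2)
print(kept)
```

With these toy inputs the two boosted mid positions are kept alongside the two sink positions and the three trailing window positions.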
Motivation
At 49k+ context the KV cache dominates decode bandwidth on Apple Silicon: each new token re-reads every cached K and V vector. With Qwen3-Next-35B-A3B-4bit at 95k, the full cache is large enough that decode is bandwidth-bound and tokens/sec falls sharply with context. SnapKV uses the model's own attention pattern on the trailing prompt window to pick the small subset of positions that actually matter for what comes next, then physically discards the rest before decode begins. The result is a fixed, much smaller cache that lets decode run at the speed it would have at the start of the prompt while preserving retrieval quality.
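
To make the "fixed, much smaller cache" concrete, here is a toy sketch of a buffer with a pinned region and a sliding recent window. It is an illustration under stated assumptions, not the PR's `SnapKVCache`: the class name `ToyFixedShapeCache` is hypothetical, and plain position labels stand in for the K/V tensors.

```python
class ToyFixedShapeCache:
    """Pinned entries never move; recent entries are evicted first-in, first-out."""

    def __init__(self, pinned, recent, offset):
        if not recent:
            raise ValueError("recent window must hold at least one entry")
        self.pinned = list(pinned)        # sink + selected mid positions
        self.recent = list(recent)        # trailing window, oldest first
        self.n_window = len(self.recent)  # fixed window capacity
        self.offset = offset              # logical position of the next token

    def update_and_fetch(self, new_entries):
        new_entries = list(new_entries)
        # The logical offset tracks every token ever seen.
        self.offset += len(new_entries)
        # The physical window keeps only the newest n_window entries.
        self.recent = (self.recent + new_entries)[-self.n_window:]
        return self.pinned + self.recent


# A 100-token prompt trimmed to 4 pinned positions plus a 3-slot window.
cache = ToyFixedShapeCache(pinned=[0, 1, 40, 57], recent=[97, 98, 99], offset=100)
first = cache.update_and_fetch([100])
second = cache.update_and_fetch([101, 102])
print(first, second, cache.offset)
```

The buffer length stays at 7 entries across both updates while the logical offset advances from 100 to 103, which is the decoupling that keeps decode cost independent of the original prompt length.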
Measured impact
Qwen3-Next-35B-A3B-4bit on M4 Max 36 GB, validated config `top_k=4096, n_sink=128, n_window=512, obs_window=32, pool_kernel=1`. Headline numbers (full performance matrix below):

SnapKV holds decode throughput essentially flat at ~98 tok/s from 32k all the way to 128k while full attention falls off with bandwidth, so the speedup grows with context. Retrieval (needle-at-start, middle, end) is 3/3 at 95k.
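
The growth of the decode speedup with context can be checked against the raw per-cell results in this PR. The sketch below copies the `decode_tps` values of the `full` and `snap` cells at `needle_pos` 0.5 and computes their ratio; the helper name `decode_speedups` is illustrative only.

```python
# (ctx_actual, full decode_tps, snap decode_tps), copied from the raw per-cell JSON.
CELLS = [
    (31994, 82.63, 99.85),
    (49039, 76.16, 100.19),
    (65352, 68.69, 98.71),
    (95075, 58.37, 98.32),
    (128076, 50.81, 98.66),
]


def decode_speedups(cells):
    """Map context length to the snap/full decode throughput ratio."""
    return {ctx: snap / full for ctx, full, snap in cells}


speedups = decode_speedups(CELLS)
for ctx, ratio in speedups.items():
    print(f"{ctx:>7} tokens: {ratio:.2f}x decode speedup")
```

These are decode-only ratios. They exclude the longer SnapKV prefill, so they are not directly comparable to an end-to-end figure.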
Below the configurable `min_ctx` threshold (default 49152) SnapKV's extra prefill pass outweighs the decode savings, so the helper short-circuits and `generate_step` behaves exactly as it would without the kwarg. The length-gate row in the matrix verifies this is a true no-op.
Performance matrix
Validation run, 2026-05-18. Driven by `mlx_lm.generate.generate_step(snapkv=...)` (the public API introduced by this PR). Driver script + raw per-cell JSON below.
Decode throughput vs context length
`snap` rows force-enable SnapKV (`min_ctx=0`) so the matrix reports SnapKV's behaviour at every context. The `4096` and `16384` cells use the PR's default `min_ctx=49152` (column appears as `snap_gated`) so they exercise the length-gate short-circuit and the ratio is by definition ~1.
Retrieval correctness: needle-in-a-haystack at 95k
Model output in each of the three cells (truncated): "The secret code phrase mentioned in the text is 'fluorescent"
3/3 hit at the validated config (`top_k=4096, n_sink=128, n_window=512, pool_kernel=1`).
Length-gate no-op (PR default `min_ctx=49152`)
Below the gate, `snapkv={...}` short-circuits inside `generate_step` without modifying the prompt cache, so decode tok/s must match full attention.
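
The gating conditions described in this PR (kwarg provided, prompt at least `min_ctx` long, fresh prompt cache, supported attention class available) can be summarized as a single predicate. This is a minimal sketch and not the helper in `mlx_lm/generate.py`: the function name `snapkv_gate_passes` and its arguments are hypothetical, and whether the real gate compares against the full prompt length or the length minus the final token is an assumption here.

```python
DEFAULT_MIN_CTX = 49152  # default length gate stated in this PR


def snapkv_gate_passes(prompt_len, cache_is_fresh, attention_available, snapkv=None):
    """Return True only when every SnapKV precondition holds (hypothetical helper)."""
    if snapkv is None:
        return False  # kwarg not provided: SnapKV is never installed
    if prompt_len < snapkv.get("min_ctx", DEFAULT_MIN_CTX):
        return False  # short prompt: the extra prefill pass would not pay off
    if not cache_is_fresh:
        return False  # mid-conversation continuation: leave the cache alone
    if not attention_available:
        return False  # no supported attention class to hook
    return True


# The 3057-token cell is gated off by default, and force-enabled with min_ctx=0.
print(snapkv_gate_passes(3057, True, True, snapkv={}))
print(snapkv_gate_passes(3057, True, True, snapkv={"min_ctx": 0}))
print(snapkv_gate_passes(95075, True, True, snapkv={}))
```

Any `False` result corresponds to the fall-through case, where the standard prefill loop runs on an unmodified prompt cache.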
Peak GPU memory (GB)
Peak is measured across the full request, which for SnapKV includes the extra
capture-prefill pass before the trim. Steady-state decode memory is lower than the
number reported here once the un-selected K/V positions are discarded, so the +1 GB
"cost" in the table is a one-shot prefill overhead, not a sustained delta.
Methodology
- `mlx-community/Qwen3.6-35B-A3B-4bit` (Qwen3-Next-35B-A3B-Instruct, 4-bit).
- `top_k=4096, n_sink=128, n_window=512, obs_window=32, pool_kernel=1`.
- `mlx_lm.generate.generate_step(prompt, model, snapkv={...})`, the public API introduced by this PR. No model-internals reach-around.
Raw per-cell results
[
  { "ctx_req": 4096, "ctx_actual": 3057, "mode": "full", "needle_pos": 0.5, "decode_tps": 108.99, "prefill_s": 2.15, "n_gen": 200, "hit": true, "peak_gb": 21.45 },
  { "ctx_req": 4096, "ctx_actual": 3057, "mode": "snap_gate_default", "needle_pos": 0.5, "decode_tps": 109.02, "prefill_s": 2.14, "n_gen": 200, "hit": true, "peak_gb": 21.45 },
  { "ctx_req": 22000, "ctx_actual": 16042, "mode": "full", "needle_pos": 0.5, "decode_tps": 94.93, "prefill_s": 12.51, "n_gen": 200, "hit": true, "peak_gb": 22.45 },
  { "ctx_req": 22000, "ctx_actual": 16042, "mode": "snap_gate_default", "needle_pos": 0.5, "decode_tps": 95.0, "prefill_s": 12.18, "n_gen": 200, "hit": true, "peak_gb": 22.45 },
  { "ctx_req": 44000, "ctx_actual": 31994, "mode": "full", "needle_pos": 0.5, "decode_tps": 82.63, "prefill_s": 28.13, "n_gen": 200, "hit": true, "peak_gb": 23.82 },
  { "ctx_req": 44000, "ctx_actual": 31994, "mode": "snap", "needle_pos": 0.5, "decode_tps": 99.85, "prefill_s": 32.44, "n_gen": 200, "hit": true, "peak_gb": 24.89 },
  { "ctx_req": 67500, "ctx_actual": 49039, "mode": "full", "needle_pos": 0.5, "decode_tps": 76.16, "prefill_s": 49.56, "n_gen": 200, "hit": true, "peak_gb": 25.26 },
  { "ctx_req": 67500, "ctx_actual": 49039, "mode": "snap", "needle_pos": 0.5, "decode_tps": 100.19, "prefill_s": 57.36, "n_gen": 200, "hit": true, "peak_gb": 26.3 },
  { "ctx_req": 90000, "ctx_actual": 65352, "mode": "full", "needle_pos": 0.5, "decode_tps": 68.69, "prefill_s": 74.8, "n_gen": 200, "hit": true, "peak_gb": 26.7 },
  { "ctx_req": 90000, "ctx_actual": 65352, "mode": "snap", "needle_pos": 0.5, "decode_tps": 98.71, "prefill_s": 87.01, "n_gen": 200, "hit": true, "peak_gb": 27.74 },
  { "ctx_req": 131000, "ctx_actual": 95075, "mode": "full", "needle_pos": 0.5, "decode_tps": 58.37, "prefill_s": 134.23, "n_gen": 200, "hit": true, "peak_gb": 29.41 },
  { "ctx_req": 131000, "ctx_actual": 95075, "mode": "snap", "needle_pos": 0.005, "decode_tps": 97.89, "prefill_s": 155.42, "n_gen": 200, "hit": true, "peak_gb": 30.45 },
  { "ctx_req": 131000, "ctx_actual": 95075, "mode": "snap", "needle_pos": 0.5, "decode_tps": 98.32, "prefill_s": 157.1, "n_gen": 200, "hit": true, "peak_gb": 30.45 },
  { "ctx_req": 131000, "ctx_actual": 95077, "mode": "snap", "needle_pos": 0.97, "decode_tps": 98.38, "prefill_s": 155.55, "n_gen": 200, "hit": true, "peak_gb": 30.45 },
  { "ctx_req": 176500, "ctx_actual": 128076, "mode": "full", "needle_pos": 0.5, "decode_tps": 50.81, "prefill_s": 218.84, "n_gen": 200, "hit": true, "peak_gb": 32.3 },
  { "ctx_req": 176500, "ctx_actual": 128076, "mode": "snap", "needle_pos": 0.5, "decode_tps": 98.66, "prefill_s": 252.88, "n_gen": 200, "hit": true, "peak_gb": 33.33 }
]
Implementation
Three files, +562 lines, no existing behavior changed:
- `mlx_lm/models/cache.py` - adds `SnapKVCache`: a fixed-shape decode-time cache that decouples the logical RoPE offset (used to RoPE-encode the next query) from the physical buffer length (used for storage). The pinned region (sink + selected middle positions) is never evicted; the trailing recent window slides as decode emits tokens, so the buffer shape is constant across decode steps and MLX's `scaled_dot_product_attention` kernel plan stays stable.
- `mlx_lm/snapkv.py` - new module:
  - `_snapkv_attention_call` - drop-in replacement for `Qwen3NextAttention.__call__` that captures the last `obs_window` post-RoPE queries when capture is enabled. Bit-identical to the original when capture is off.
  - `patch_for_snapkv` / `unpatch_snapkv` - install / restore the patch at the class level (not instance: Python dispatches `obj(x)` via `type(obj).__call__`, so an instance-level monkey-patch silently never fires).
  - `_snapkv_select_indices` - GQA-reduce queries, score K, optional 1D box-filter pooling, top-K `argpartition` over the mid region, return sorted-ascending indices.
  - `snapkv_prefill_and_trim` - chunked prefill driver that swaps the trimmed `SnapKVCache` into the prompt cache list in place.
- `mlx_lm/generate.py` - adds `snapkv: Optional[Dict[str, Any]] = None` kwarg to `generate_step`. When provided, a small helper checks the length gate (`min_ctx`), confirms the prompt cache is fresh, and confirms the Qwen3-Next attention class is importable; on success it runs SnapKV prefill+trim on `prompt[:-1]` and leaves the final token for the existing decode path so logits processors / sampler / KV-quantize hooks fire normally on the first emitted token. On any miss it returns the original prompt unchanged and the function falls through to the standard prefill loop.

Recognized `snapkv` keys: `top_k` (default 4096), `n_sink` (128), `n_window` (512), `obs_window` (32), `pool_kernel` (1), `min_ctx` (49152), `prefill_chunk` (`prefill_step_size`).
Scope
- `Qwen3NextAttention` only (the model where SnapKV is bench-validated upstream of this PR).
- The selection logic and `SnapKVCache` itself are model-agnostic; extending coverage to another attention class only requires routing `patch_for_snapkv` through that class and confirming it has a compatible post-RoPE query path.
Backwards compatibility
- `snapkv=None` is the default and means SnapKV is never installed; `generate_step` behaves exactly as before. (Verified by `test_generate_step_snapkv_none_is_default`.)
- `snapkv={...}` on a prompt shorter than `min_ctx` short-circuits before any class is touched. (Verified by `test_generate_step_snapkv_min_ctx_gate_skips_short_prompt`.)
- `snapkv={...}` on a pre-populated prompt cache (mid-conversation continuation) detects the populated cache and falls through. (Verified by `test_generate_step_snapkv_skips_when_cache_already_populated`.)
Tests
- `tests/test_models.py` - new `TestSnapKVCache` class:
  - `test_update_and_fetch_sliding_window` - pinned region stays put across single-token and multi-token updates; FIFO recent window evicts oldest entries; logical offset advances by `L` while physical shape stays constant.
  - `test_n_pin_validation` - constructor rejects `n_pin > n_keep`.
  - `test_snapkv_cache_rejected_by_trim_helpers` - `can_trim_prompt_cache` returns False when the list contains a `SnapKVCache` (because it is intentionally `is_trimmable() == False`).
  - `test_select_indices_orders_sink_topk_window` - boosted mid positions win argpartition; output is `sink + selected_mid + window` sorted ascending; dtype is int32; total length is `n_sink + top_k + n_window`.
  - `test_select_indices_returns_all_when_mid_empty` - degenerate case where `n_sink + n_window >= T` returns `arange(T)` unchanged.
  - `test_select_indices_top_k_zero` - `top_k=0` returns only `sink + window`.
- `tests/test_generate.py` - new `generate_step` tests:
  - `test_generate_step_snapkv_none_is_default` - explicit `snapkv=None` matches the no-kwarg path token-for-token.
  - `test_generate_step_snapkv_min_ctx_gate_skips_short_prompt` - large `min_ctx` short-circuits to identical output as `snapkv=None`.
  - `test_generate_step_snapkv_skips_when_cache_already_populated` - populated `prompt_cache` triggers the early return; output matches the standard mid-conversation continuation path.
Limitations
- `SnapKVCache` is intentionally `is_trimmable() == False` - it is a lossy summary of the prompt, not a replayable history. `trim_prompt_cache` will refuse a list containing it.
- `SnapKVCache` does not expose `to_quantized`. Callers that want KV-quant should not enable SnapKV on the same request; they are alternative bandwidth strategies.
- `SnapKVCache` cannot be saved or loaded via `save_prompt_cache` / `load_prompt_cache` - the kept K vectors retain their original RoPE positions, so the cache is only valid against the original prompt and the original model state.

Fork validation PR. Not for upstream submission until I run wider model coverage and the upstream `mac_build_and_test` flake on tokenizer tests is sorted out.
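
As a footnote to the class-level patching noted under Implementation: Python looks up `__call__` on the type, not on the instance, when evaluating `obj(x)`. The toy below demonstrates that with a stand-in `Layer` class; the class and function names are illustrative and unrelated to the actual `Qwen3NextAttention` code.

```python
class Layer:
    def __call__(self, x):
        return ("original", x)


def patched_call(self, x):
    return ("patched", x)


def demo():
    layer = Layer()

    # Instance-level patch: stored in the instance dict, ignored by layer(x).
    layer.__call__ = lambda x: ("instance", x)
    instance_result = layer(1)

    # Class-level patch: picked up by every instance, including existing ones.
    original = Layer.__call__
    Layer.__call__ = patched_call
    try:
        class_result = layer(1)
    finally:
        Layer.__call__ = original  # restore, mirroring an unpatch step

    restored_result = layer(1)
    return instance_result, class_result, restored_result


print(demo())
```

The first and last calls go through the original method and only the class-level assignment changes what `layer(1)` returns, which is why an instance-level monkey-patch would silently never fire.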