Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Expand Down Expand Up @@ -304,6 +305,58 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)


def _maybe_snapkv_prefill(
*,
prompt: mx.array,
model: nn.Module,
prompt_cache,
opts: Dict[str, Any],
prefill_step_size: int,
) -> mx.array:
"""Run SnapKV prefill+trim in place on ``prompt_cache``.

Returns the portion of the prompt still to be consumed by the caller --
SnapKV consumes ``prompt[:-1]`` and leaves the final token for the
regular ``generate_step`` decode path so logits_processors / sampler /
quantize hooks fire normally on the first emitted token.

Returns the original ``prompt`` unchanged if SnapKV is skipped (prompt
shorter than ``min_ctx``, ``prompt_cache`` already populated, model is
not Qwen3-Next, etc.).
"""
prompt_len = int(prompt.shape[0])
min_ctx = int(opts.get("min_ctx", 49152))
if prompt_len < max(2, min_ctx):
return prompt
if prompt_cache and getattr(prompt_cache[0], "offset", 0) > 0:
return prompt
try:
from .snapkv import patch_for_snapkv, snapkv_prefill_and_trim
except ImportError:
return prompt

obs_window = int(opts.get("obs_window", 32))
try:
patch_for_snapkv(model, obs_window=obs_window)
except ImportError:
return prompt

head = prompt[:-1]
tail = prompt[-1:]
new_cache, _ = snapkv_prefill_and_trim(
model,
head,
top_k=int(opts.get("top_k", 4096)),
n_sink=int(opts.get("n_sink", 128)),
n_window=int(opts.get("n_window", 512)),
obs_window=obs_window,
pool_kernel=int(opts.get("pool_kernel", 1)),
prefill_chunk=int(opts.get("prefill_chunk", prefill_step_size)),
)
prompt_cache[:] = new_cache
return tail


def generate_step(
prompt: mx.array,
model: nn.Module,
Expand All @@ -319,6 +372,7 @@ def generate_step(
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
input_embeddings: Optional[mx.array] = None,
snapkv: Optional[Dict[str, Any]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Expand Down Expand Up @@ -347,6 +401,14 @@ def generate_step(
prompt tokens processed so far and the total number of prompt tokens.
input_embeddings (mx.array, optional): Input embeddings to use instead of or in
conjunction with prompt tokens. Default: ``None``.
snapkv (dict, optional): If provided, enable SnapKV content-aware
KV-cache compression for long-context prompts. Recognized keys:
``top_k`` (default 4096), ``n_sink`` (default 128), ``n_window``
(default 512), ``obs_window`` (default 32), ``pool_kernel``
(default 1), ``min_ctx`` (default 49152). SnapKV is skipped for
prompts shorter than ``min_ctx``. Only takes effect on models
exposing the Qwen3-Next attention class. See ``mlx_lm.snapkv``.
Default: ``None``.

Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
Expand Down Expand Up @@ -374,6 +436,20 @@ def generate_step(
max_kv_size=max_kv_size,
)

# Optional: SnapKV content-aware KV compression for long-context prompts.
# Performs an extra prefill pass that captures attention scores from the
# final ``obs_window`` queries, scores each cached K position, and
# physically drops the un-selected entries. Only fires when the prompt is
# long enough (``min_ctx``) and the model exposes Qwen3-Next attention.
if snapkv is not None and input_embeddings is None and len(prompt) > 0:
prompt = _maybe_snapkv_prefill(
prompt=prompt,
model=model,
prompt_cache=prompt_cache,
opts=snapkv,
prefill_step_size=prefill_step_size,
)

prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

quantize_cache_fn = functools.partial(
Expand Down
108 changes: 108 additions & 0 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,114 @@ def nbytes(self):
return self.keys.nbytes + self.values.nbytes


class SnapKVCache(_BaseCache):
"""Fixed-shape KV cache for SnapKV-style content-aware compression.

Constructed by :func:`mlx_lm.snapkv.snapkv_prefill_and_trim` after a
one-shot prefill of the prompt: K/V positions are scored using the last
``obs_window`` queries, top-K positions are kept (plus an attention
sink and a recent-tokens window), and the un-selected K/V entries are
physically dropped.

The cache layout (axis=2 is sequence):

``[ 0 : n_pin ) pinned (sink + selected middle positions)``
``[ n_pin : n_keep ) sliding recent-tokens window``

During decode, new K/V enters at the tail of the recent window and the
oldest recent entry is evicted, so the buffer shape is **constant**
across decode steps. This keeps MLX's
``mx.fast.scaled_dot_product_attention`` kernel plan stable across
iterations.

The kept K vectors retain their original RoPE positions; they are
**not** re-encoded after trim. The ``offset`` property therefore
returns the *logical* prompt length (so the next query's RoPE offset
is correct) rather than the smaller physical cache length.

This cache is **not trimmable** (``trim_prompt_cache`` will refuse it),
not quantizable, and not persistable: it represents a lossy summary
of the prompt, not the full prompt.
"""

def __init__(
self,
keys: mx.array,
values: mx.array,
logical_offset: int,
*,
n_pin: int,
):
B, KV_H, n_keep, _ = keys.shape
if n_pin > n_keep:
raise ValueError(f"n_pin={n_pin} > n_keep={n_keep}")
self._n_keep = int(n_keep)
self._n_pin = int(n_pin)
self.keys = keys
self.values = values
self._logical = int(logical_offset)

@property
def offset(self) -> int:
return self._logical

@offset.setter
def offset(self, value: int) -> None:
self._logical = int(value)

def update_and_fetch(self, keys: mx.array, values: mx.array):
# Slide the recent-tokens window: pinned | recent[L:] | new.
L = keys.shape[-2]
n_pin = self._n_pin
n_keep = self._n_keep
self.keys = mx.concatenate(
[
self.keys[..., :n_pin, :],
self.keys[..., n_pin + L : n_keep, :],
keys,
],
axis=-2,
)
self.values = mx.concatenate(
[
self.values[..., :n_pin, :],
self.values[..., n_pin + L : n_keep, :],
values,
],
axis=-2,
)
self._logical += L
return self.keys, self.values

def size(self) -> int:
return self._n_keep

@property
def state(self):
return self.keys, self.values

@state.setter
def state(self, v):
self.keys, self.values = v
self._n_keep = self.keys.shape[-2]

def is_trimmable(self) -> bool:
return False

@property
def nbytes(self) -> int:
if self.keys is None:
return 0
return self.keys.nbytes + self.values.nbytes

def empty(self) -> bool:
return self.keys is None

def make_mask(self, *args, **kwargs):
# Buffer is fully populated; no masking required for SDPA.
return None


class RotatingKVCache(_BaseCache):
step = 256

Expand Down
Loading
Loading