From 3b2a129240c01296c8b3e3a0ed068fb482a3f152 Mon Sep 17 00:00:00 2001 From: benjamin-levin Date: Mon, 18 May 2026 20:11:32 -0400 Subject: [PATCH] generate: add prompt lookup decoding (PLD) with bit-exact rollback Adds a Prompt Lookup Decoding (PLD) generator that drafts from n-gram matches in the prompt and verifies in a single forward, avoiding the need for a separate draft model. Rollback on partial accept is exact (snapshot + restore + re-forward of the accepted prefix), so PLD output matches plain auto-regressive decoding under the same sampler. The cache module gains generic snap() / restore() hooks plus snapshot_prompt_cache / restore_prompt_cache helpers. Unlike trim_prompt_cache, these work for non-trimmable recurrent caches (ArraysCache used by Gated Delta Net / Mamba-style models), so PLD is bit-exact across every built-in cache type. Wiring: - mlx_lm.generate gains --prompt-lookup-num-tokens. - stream_generate accepts prompt_lookup_num_tokens=N (mutually exclusive with draft_model). Tests: - tests/test_models.py: snap/restore round-trip per cache class plus module-level helpers. - tests/test_generate.py: n-gram lookup helper, step-generator yield shape, AR-equivalence under greedy sampler, arg validation, and stream_generate wiring. --- mlx_lm/generate.py | 270 ++++++++++++++++++++++++++++++++++++++++- mlx_lm/models/cache.py | 124 +++++++++++++++++++ tests/test_generate.py | 150 +++++++++++++++++++++++ tests/test_models.py | 217 +++++++++++++++++++++++++++++++++ 4 files changed, 760 insertions(+), 1 deletion(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..44346d254 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -219,6 +219,14 @@ def setup_arg_parser(): help="Number of tokens to draft when using speculative decoding.", default=3, ) + parser.add_argument( + "--prompt-lookup-num-tokens", + type=int, + help="If set, use Prompt Lookup Decoding (PLD) drafting this many " + "tokens per cycle from prompt n-grams. No draft model required. " + "Mutually exclusive with --draft-model.", + default=None, + ) return parser @@ -654,12 +662,252 @@ def _draft_generate(y, num_draft): _rewind_cache(num_draft, n) +def _pld_find_draft( + generated: List[int], + prompt: List[int], + k_lookback: int, + k_lookahead: int, +) -> List[int]: + """ + Find a PLD draft: match the last ``k_lookback`` generated tokens against + the prompt right-to-left and return the next ``k_lookahead`` prompt + tokens after the most recent match. Empty list if no match (caller + falls back to plain auto-regressive decoding). + """ + if len(generated) < k_lookback: + return [] + suffix = generated[-k_lookback:] + for i in range(len(prompt) - k_lookback, -1, -1): + if prompt[i : i + k_lookback] == suffix: + return prompt[i + k_lookback : i + k_lookback + k_lookahead] + return [] + + +def prompt_lookup_generate_step( + prompt: mx.array, + model: nn.Module, + *, + prompt_lookup_num_tokens: int = 5, + prompt_lookup_min_match: int = 3, + max_tokens: int = 256, + sampler: Optional[Callable[[mx.array], mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 512, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, +) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: + """ + Prompt Lookup Decoding (PLD): a training-free speculative-decoding + variant where the "draft" is an n-gram lookup in the prompt rather than + a separate draft model. At each step the last few generated tokens are + matched against the prompt; the next ``prompt_lookup_num_tokens`` tokens + in the prompt after the match are verified in a single forward pass. + + Acceptance is highest when generation reproduces prompt content + (translation, code edit, summarisation, RAG answer); otherwise PLD + falls back to plain auto-regressive decoding for that step. + + Rollback on partial accept uses snapshot + restore + re-forward of the + accepted prefix. This is bit-exact for both KV caches and non-trimmable + recurrent caches (e.g. ``ArraysCache`` used by Gated Delta Net / + Mamba-style models), so PLD output matches plain AR decoding under the + same sampler. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + prompt_lookup_num_tokens (int): Lookahead depth -- maximum number of + tokens drafted from the prompt per cycle. Default: ``5``. + prompt_lookup_min_match (int): N-gram length used to match the + suffix of the generated text against the prompt. Default: ``3``. + max_tokens (int): Maximum number of tokens to generate. Default: + ``256``. + sampler (Callable, optional): Sampler from log probabilities. + Default: argmax (greedy). + logits_processors (List[Callable], optional): Optional logits + processors. Default: ``None``. + prompt_cache (List[Any], optional): A pre-computed prompt cache. + Updated in place if provided. + prefill_step_size (int): Step size for processing the prompt. + kv_bits, kv_group_size, quantized_kv_start: KV-cache quantisation + options (see :func:`generate_step`). + + Yields: + Tuple[mx.array, mx.array, bool]: One token id, the vector of log + probabilities, and a bool that is ``True`` when the token was + drafted from the prompt and accepted (the analogue of + ``from_draft`` in :func:`speculative_generate_step`). + """ + if prompt_lookup_num_tokens < 1: + raise ValueError("prompt_lookup_num_tokens must be >= 1") + if prompt_lookup_min_match < 1: + raise ValueError("prompt_lookup_min_match must be >= 1") + + y = prompt.astype(mx.uint32) + prompt_ids = prompt.tolist() if hasattr(prompt, "tolist") else list(prompt) + generated: List[int] = [] + prev_tokens: Optional[mx.array] = None + + if prompt_cache is None: + model_cache = cache.make_prompt_cache(model) + else: + model_cache = prompt_cache + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + def _process_and_sample(tokens, logits): + if logits_processors: + for processor in logits_processors: + logits = processor(tokens, logits) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + return sampler(logprobs), logprobs + + def _step(y_in, n_predict=1): + with mx.stream(generation_stream): + logits = model(y_in[None], cache=model_cache) + logits = logits[:, -n_predict:, :] + quantize_cache_fn(model_cache) + if logits_processors: + nonlocal prev_tokens + out_y, out_logprobs = [], [] + y_seen = y_in[: -(n_predict - 1)] if n_predict > 1 else y_in + for i in range(n_predict): + prev_tokens = ( + mx.concatenate([prev_tokens, y_seen]) + if prev_tokens is not None + else y_seen + ) + tok, lp = _process_and_sample(prev_tokens, logits[:, i, :]) + out_y.append(tok) + out_logprobs.append(lp) + return mx.concatenate(out_y, axis=0), mx.concatenate( + out_logprobs, axis=0 + ) + else: + return _process_and_sample(None, logits.squeeze(0)) + + def _prefill(y_in): + while y_in.size > 1: + n_to_process = min(prefill_step_size, y_in.size - 1) + model(y_in[:n_to_process][None], cache=model_cache) + quantize_cache_fn(model_cache) + mx.eval([c.state for c in model_cache]) + y_in = y_in[n_to_process:] + mx.clear_cache() + return y_in + + with mx.stream(generation_stream): + y = _prefill(y) + + ntoks = 0 + try: + while ntoks < max_tokens: + # Look up a draft in the prompt using the recent generation tail. + # If we have not generated enough tokens yet (or the recent tail + # has no n-gram match in the prompt), fall back to plain AR. + draft = _pld_find_draft( + generated, + prompt_ids, + k_lookback=prompt_lookup_min_match, + k_lookahead=min(prompt_lookup_num_tokens, max_tokens - ntoks - 1), + ) + + if not draft: + # AR fallback: single-token forward. + tok, lp = _step(y, n_predict=1) + mx.eval(tok) + tok_id = tok.item() + generated.append(tok_id) + ntoks += 1 + yield tok_id, lp, False + if ntoks == max_tokens: + break + y = tok + continue + + # Verify [y, draft_1, ..., draft_k] in one forward. Snapshot the + # pre-verify state of every cache so we can roll back on partial + # accept; the snapshot is by-reference and is dropped on full + # accept. + k = len(draft) + pre_snap = cache.snapshot_prompt_cache(model_cache) + verify_in = mx.concatenate([y, mx.array(draft, mx.uint32)]) + tokens, logprobs = _step(verify_in, n_predict=k + 1) + mx.eval(tokens) + tokens_list = tokens.tolist() + + # Count accepted draft tokens (longest matching prefix). + n_accept = 0 + for i in range(k): + if tokens_list[i] != draft[i]: + break + n_accept += 1 + + # Yield accepted drafts. + for i in range(n_accept): + if ntoks == max_tokens: + break + generated.append(tokens_list[i]) + ntoks += 1 + yield tokens_list[i], logprobs[i], True + + if ntoks < max_tokens: + # Yield the bonus / correction token from the verify forward. + bonus = tokens_list[n_accept] + generated.append(bonus) + ntoks += 1 + yield bonus, logprobs[n_accept], False + + if ntoks == max_tokens: + break + + if n_accept == k: + # Full accept: cache already at the correct post-verify + # state, drop the snapshot. + y = mx.array([tokens_list[k]], mx.uint32) + else: + # Partial accept: restore both caches to pre-verify state + # and re-forward the accepted prefix so the cache lands at + # exactly "after y + accepted drafts". This is bit-exact + # for both KV and recurrent caches and costs one extra + # forward of size (n_accept + 1) tokens. + cache.restore_prompt_cache(model_cache, pre_snap) + accepted = mx.array(draft[:n_accept], mx.uint32) + reforward = mx.concatenate([y, accepted]) + with mx.stream(generation_stream): + model(reforward[None], cache=model_cache) + quantize_cache_fn(model_cache) + # Trim any logits-processor history that overshot the + # accepted range (the verify pulled (k+1) tokens through + # the processor; we only accepted (n_accept + 1)). + if prev_tokens is not None and logits_processors: + overshoot = k - n_accept + if overshoot > 0: + prev_tokens = prev_tokens[:-overshoot] + y = mx.array([tokens_list[n_accept]], mx.uint32) + finally: + # Nothing to roll back on early exit: every yielded token's cache + # state was either left as-is (full accept) or already restored + + # re-forwarded (partial accept) before the next iteration. + pass + + def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: Union[str, mx.array, List[int]], max_tokens: int = 256, draft_model: Optional[nn.Module] = None, + prompt_lookup_num_tokens: Optional[int] = None, **kwargs, ) -> Generator[GenerationResponse, None, None]: """ @@ -675,6 +923,10 @@ def stream_generate( draft_model (Optional[nn.Module]): An optional draft model. If provided then speculative decoding is used. The draft model must use the same tokenizer as the main model. Default: ``None``. + prompt_lookup_num_tokens (Optional[int]): If set, use Prompt Lookup + Decoding (PLD) with this many drafted tokens per cycle. PLD drafts + from n-gram matches in the prompt -- no draft model is required. + Mutually exclusive with ``draft_model``. Default: ``None``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. @@ -698,13 +950,28 @@ def stream_generate( kwargs["max_tokens"] = max_tokens - if draft_model is None: + if draft_model is not None and prompt_lookup_num_tokens is not None: + raise ValueError( + "draft_model and prompt_lookup_num_tokens are mutually exclusive." + ) + if draft_model is None and prompt_lookup_num_tokens is None: kwargs.pop("num_draft_tokens", None) token_generator = generate_step(prompt, model, **kwargs) # from_draft always false for non-speculative generation token_generator = ( (token, logprobs, False) for token, logprobs in token_generator ) + elif prompt_lookup_num_tokens is not None: + kwargs.pop("max_kv_size", None) + kwargs.pop("prompt_progress_callback", None) + kwargs.pop("num_draft_tokens", None) + kwargs.pop("input_embeddings", None) + token_generator = prompt_lookup_generate_step( + prompt, + model, + prompt_lookup_num_tokens=prompt_lookup_num_tokens, + **kwargs, + ) else: kwargs.pop("max_kv_size", None) kwargs.pop("prompt_progress_callback", None) @@ -2083,6 +2350,7 @@ def main(): quantized_kv_start=args.quantized_kv_start, draft_model=draft_model, num_draft_tokens=args.num_draft_tokens, + prompt_lookup_num_tokens=args.prompt_lookup_num_tokens, ) if not args.verbose: print(response) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index b84c9d650..90b2390de 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -111,6 +111,45 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: return [c.trim(num_tokens) for c in cache][0] +def snapshot_prompt_cache(cache: List[Any]) -> List[Any]: + """ + Snapshot the current state of every cache for exact rollback. + + Unlike :func:`trim_prompt_cache`, this works for *all* cache types, + including non-trimmable recurrent caches (e.g. :class:`ArraysCache` + used by Gated Delta Net / Mamba-style models). Pair with + :func:`restore_prompt_cache` to revert the cache state after a + speculative-decoding verify forward. + + The snapshot is taken by reference where possible (no tensor copy) and + is therefore cheap to take and discard. It is safe as long as the + cache implementations only *append* or *replace by reference* during + intervening forwards -- which is true of every built-in cache here. + + Args: + cache (List[Any]): The model's cache. + + Returns: + List[Any]: An opaque list of per-cache snapshots, suitable for + passing to :func:`restore_prompt_cache`. + """ + return [c.snap() for c in cache] + + +def restore_prompt_cache(cache: List[Any], snapshot: List[Any]) -> None: + """ + Restore every cache to the state captured by + :func:`snapshot_prompt_cache`. + + Args: + cache (List[Any]): The model's cache. + snapshot (List[Any]): The snapshot returned by + :func:`snapshot_prompt_cache`. + """ + for c, s in zip(cache, snapshot): + c.restore(s) + + def create_attention_mask( N: int, offset: int, return_array: bool, window_size: Optional[int] ): @@ -146,6 +185,20 @@ def meta_state(self, v): def is_trimmable(self): return False + def snap(self): + """ + Snapshot the cache state for exact rollback. Override in + subclasses that mutate state during a forward. + """ + return None + + def restore(self, snapshot): + """ + Restore the cache to the state captured by :meth:`snap`. Default + is a no-op for caches with no mutable state. + """ + return + def size(self): """ Return the size (i.e. sequence length) of the cache. @@ -216,6 +269,15 @@ def trim(self, n): self.offset -= n return n + def snap(self): + # ConcatenateKVCache replaces keys/values by mx.concatenate, so the + # snapshot must hold the pre-update array references along with the + # offset; restoring all three reverts any intervening appends. + return (self.offset, self.keys, self.values) + + def restore(self, snapshot): + self.offset, self.keys, self.values = snapshot + def make_mask(self, *args, **kwargs): return create_attention_mask(*args, offset=self.offset, **kwargs) @@ -311,6 +373,14 @@ def trim(self, n): self.offset -= n return n + def snap(self): + # Writes happen in pre-allocated buffer at positions >= self.offset; + # restoring offset hides them (next write reuses the slots). + return self.offset + + def restore(self, snapshot): + self.offset = snapshot + def make_mask(self, *args, **kwargs): return create_attention_mask(*args, offset=self.offset, **kwargs) @@ -380,6 +450,14 @@ def trim(self, n): self.offset -= n return n + def snap(self): + # Writes happen at positions [prev_offset, offset); restoring the + # offset reverts the verify-forward writes without copying tensors. + return self.offset + + def restore(self, snapshot): + self.offset = snapshot + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) quant_cache.offset = self.offset @@ -548,6 +626,15 @@ def trim(self, n): self._idx -= n return n + def snap(self): + # Pre-rotation writes are in-place at [_idx, _idx + S); restoring + # offset and _idx reverts them. Post-rotation cases also restore + # by reverting _idx. + return (self.offset, self._idx) + + def restore(self, snapshot): + self.offset, self._idx = snapshot + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: raise NotImplementedError("RotatingKVCache Quantization NYI") @@ -688,6 +775,18 @@ def advance(self, N): if self.left_padding is not None: self.left_padding -= N + def snap(self): + # The model __call__ reassigns cache[i] by reference and advance() + # reassigns lengths/left_padding via -=, so a by-reference snapshot + # of the list and these two arrays is sufficient -- the originals + # are never mutated in place during a forward. + return (list(self.cache), self.lengths, self.left_padding) + + def restore(self, snapshot): + self.cache = list(snapshot[0]) + self.lengths = snapshot[1] + self.left_padding = snapshot[2] + def make_mask(self, N: int): if self.left_padding is not None: pos = mx.arange(N) @@ -793,6 +892,12 @@ def trim(self, n): self.offset -= n return n + def snap(self): + return (self.offset, self.start_position) + + def restore(self, snapshot): + self.offset, self.start_position = snapshot + @property def meta_state(self): return tuple(map(str, (self.chunk_size, self.start_position))) @@ -826,6 +931,13 @@ def trim(self, n): m = c.trim(n) return m + def snap(self): + return [c.snap() for c in self.caches] + + def restore(self, snapshot): + for c, s in zip(self.caches, snapshot): + c.restore(s) + @property def state(self): return [c.state for c in self.caches] @@ -1008,6 +1120,12 @@ def trim(self, n): self.offset -= n return n + def snap(self): + return (self._idx, self.offset) + + def restore(self, snapshot): + self._idx, self.offset = snapshot + def make_mask(self, N: int, return_array: bool = False, **kwargs): return create_causal_mask( N, offset=self._idx, left_padding=self.left_padding, **kwargs @@ -1324,6 +1442,12 @@ def trim(self, n): self.offset -= n return n + def snap(self): + return (self._offset, self._idx, self.offset) + + def restore(self, snapshot): + self._offset, self._idx, self.offset = snapshot + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: raise NotImplementedError("BatchRotatingKVCache Quantization NYI") diff --git a/tests/test_generate.py b/tests/test_generate.py index 4f5bb4c91..3ec9ca75f 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -807,5 +807,155 @@ def test_batch_max_kv_size_none_creates_regular_cache(self): self.assertIsInstance(cache, KVCache) +class TestPromptLookupDecoding(unittest.TestCase): + """Tests for Prompt Lookup Decoding (PLD) and its bit-exact rollback. + + Mirrors the structure of TestGenerate: loads a small chat model once, + then runs pure-Python tests for the n-gram lookup helper plus + end-to-end tests against the model. + """ + + @classmethod + def setUpClass(cls): + cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH) + cls.model.set_dtype(mx.float32) + + def test_pld_find_draft_returns_match(self): + from mlx_lm.generate import _pld_find_draft + + prompt = [10, 20, 30, 40, 50, 60, 70] + generated = [99, 40, 50] + # last 2 generated == [40, 50], match in prompt at idx 3,4 -> next 3 are [60, 70] + out = _pld_find_draft(generated, prompt, k_lookback=2, k_lookahead=3) + self.assertEqual(out, [60, 70]) + + def test_pld_find_draft_no_match(self): + from mlx_lm.generate import _pld_find_draft + + prompt = [10, 20, 30] + generated = [100, 200, 300] + out = _pld_find_draft(generated, prompt, k_lookback=2, k_lookahead=3) + self.assertEqual(out, []) + + def test_pld_find_draft_picks_most_recent(self): + from mlx_lm.generate import _pld_find_draft + + # Two matches; loop goes right-to-left and returns the most recent. + prompt = [1, 2, 3, 99, 1, 2, 4, 5] + generated = [9, 1, 2] + out = _pld_find_draft(generated, prompt, k_lookback=2, k_lookahead=2) + self.assertEqual(out, [4, 5]) + + def test_pld_find_draft_insufficient_history(self): + from mlx_lm.generate import _pld_find_draft + + prompt = [1, 2, 3, 4] + generated = [1] # k_lookback=2 > len(generated) + out = _pld_find_draft(generated, prompt, k_lookback=2, k_lookahead=2) + self.assertEqual(out, []) + + def test_prompt_lookup_generate_step_yield_shape(self): + # Smoke test the step generator's output shape: each yield is + # (int_token, mx.array_logprobs, bool_from_draft). + from mlx_lm.generate import prompt_lookup_generate_step + + prompt = self.tokenizer.encode("hello world", return_tensors="mlx")[0] + n = 0 + for tok, lp, from_draft in prompt_lookup_generate_step( + prompt, + self.model, + prompt_lookup_num_tokens=3, + prompt_lookup_min_match=2, + max_tokens=4, + ): + self.assertIsInstance(tok, int) + self.assertIsInstance(lp, mx.array) + self.assertIsInstance(from_draft, bool) + n += 1 + self.assertEqual(n, 4) + + def test_prompt_lookup_generate_step_matches_ar(self): + # PLD with bit-exact rollback must yield the same tokens as plain + # auto-regressive decoding under the same (greedy) sampler. + from mlx_lm.generate import ( + generate_step, + prompt_lookup_generate_step, + ) + + # A prompt with enough internal repetition to give PLD something + # to draft from (otherwise it just falls back to AR every step). + prompt = self.tokenizer.encode( + "The cat sat on the mat. The cat sat on the mat. The cat sat on", + return_tensors="mlx", + )[0] + + ar_ids = [] + for i, (tok, _) in enumerate(generate_step(prompt, self.model)): + ar_ids.append(tok) + if i + 1 == 8: + break + + pld_ids = [] + for tok, _, _ in prompt_lookup_generate_step( + prompt, + self.model, + prompt_lookup_num_tokens=4, + prompt_lookup_min_match=2, + max_tokens=8, + ): + pld_ids.append(tok) + + self.assertEqual(pld_ids, ar_ids) + + def test_prompt_lookup_generate_step_rejects_bad_args(self): + from mlx_lm.generate import prompt_lookup_generate_step + + prompt = self.tokenizer.encode("hi", return_tensors="mlx")[0] + with self.assertRaises(ValueError): + next( + prompt_lookup_generate_step( + prompt, self.model, prompt_lookup_num_tokens=0, max_tokens=1 + ) + ) + with self.assertRaises(ValueError): + next( + prompt_lookup_generate_step( + prompt, self.model, prompt_lookup_min_match=0, max_tokens=1 + ) + ) + + def test_stream_generate_prompt_lookup(self): + # End-to-end wiring: stream_generate(prompt_lookup_num_tokens=...) + # routes through prompt_lookup_generate_step and surfaces tokens. + prompt = self.tokenizer.apply_chat_template( + [{"role": "user", "content": "hello"}], + add_generation_prompt=True, + ) + n = 0 + for resp in stream_generate( + self.model, + self.tokenizer, + prompt, + max_tokens=4, + prompt_lookup_num_tokens=3, + ): + n += 1 + self.assertEqual(n, 4) + + def test_stream_generate_prompt_lookup_conflicts_with_draft(self): + # draft_model and prompt_lookup_num_tokens are mutually exclusive. + with self.assertRaises(ValueError): + for _ in stream_generate( + self.model, + self.tokenizer, + "hello", + max_tokens=2, + draft_model=self.model, + prompt_lookup_num_tokens=2, + ): + pass + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_models.py b/tests/test_models.py index 6e1fcd96e..98a2df40f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3221,5 +3221,222 @@ def test_gated_delta_masked(self): self.assertTrue(mx.allclose(st, st_gt, rtol=1e-4, atol=1e-3)) +class TestCacheSnapRestore(unittest.TestCase): + """Snap/restore round-trip tests for every built-in cache class. + + The contract: take a snapshot, perform additional updates, restore, + and confirm the cache state matches what would be produced by + stopping at the snapshot point. + """ + + def test_base_cache_snap_is_noop(self): + from mlx_lm.models.cache import _BaseCache + + c = _BaseCache() + snap = c.snap() + self.assertIsNone(snap) + # restore must accept and be a no-op. + c.restore(snap) + + def test_kv_cache_snap_restore(self): + from mlx_lm.models.cache import KVCache + + cache = KVCache() + x1 = mx.random.uniform(shape=(1, 2, 3, 8)) + cache.update_and_fetch(x1, x1) + + snap = cache.snap() + offset_at_snap = cache.offset + keys_at_snap = cache.keys[..., :offset_at_snap, :].astype(mx.float32) + + # Append more, then roll back. + x2 = mx.random.uniform(shape=(1, 2, 5, 8)) + cache.update_and_fetch(x2, x2) + self.assertEqual(cache.offset, 8) + + cache.restore(snap) + self.assertEqual(cache.offset, offset_at_snap) + self.assertTrue( + mx.array_equal( + cache.keys[..., :offset_at_snap, :].astype(mx.float32), keys_at_snap + ) + ) + + # Subsequent writes land at the rolled-back offset. + x3 = mx.random.uniform(shape=(1, 2, 2, 8)) + cache.update_and_fetch(x3, x3) + self.assertEqual(cache.offset, offset_at_snap + 2) + self.assertTrue( + mx.array_equal(cache.keys[..., offset_at_snap : offset_at_snap + 2, :], x3) + ) + + def test_quantized_kv_cache_snap_restore(self): + from mlx_lm.models.cache import QuantizedKVCache + + cache = QuantizedKVCache(group_size=64, bits=8) + x1 = mx.random.uniform(shape=(1, 2, 3, 64)) + cache.update_and_fetch(x1, x1) + snap = cache.snap() + offset_at_snap = cache.offset + + x2 = mx.random.uniform(shape=(1, 2, 4, 64)) + cache.update_and_fetch(x2, x2) + self.assertEqual(cache.offset, 7) + + cache.restore(snap) + self.assertEqual(cache.offset, offset_at_snap) + + # Next write reuses the slots starting at restored offset. + x3 = mx.random.uniform(shape=(1, 2, 2, 64)) + k_out, _ = cache.update_and_fetch(x3, x3) + self.assertEqual(cache.offset, offset_at_snap + 2) + self.assertEqual(k_out[0].shape[-2], offset_at_snap + 2) + + def test_rotating_kv_cache_snap_restore(self): + from mlx_lm.models.cache import RotatingKVCache + + cache = RotatingKVCache(max_size=8, keep=0) + x = mx.random.uniform(shape=(1, 2, 3, 4)) + cache.update_and_fetch(x, x) + snap = cache.snap() + offset_at_snap, idx_at_snap = cache.offset, cache._idx + + # Do a couple of single-step writes (the path snap/restore is built + # for: speculative verify is always shape[2] == 1 per step). + for _ in range(3): + y = mx.random.uniform(shape=(1, 2, 1, 4)) + cache.update_and_fetch(y, y) + self.assertEqual(cache.offset, 6) + + cache.restore(snap) + self.assertEqual(cache.offset, offset_at_snap) + self.assertEqual(cache._idx, idx_at_snap) + + def test_chunked_kv_cache_snap_restore(self): + from mlx_lm.models.cache import ChunkedKVCache + + cache = ChunkedKVCache(chunk_size=16) + x1 = mx.random.uniform(shape=(1, 2, 3, 8)) + cache.update_and_fetch(x1, x1) + snap = cache.snap() + offset_at_snap = cache.offset + start_at_snap = cache.start_position + + x2 = mx.random.uniform(shape=(1, 2, 4, 8)) + cache.update_and_fetch(x2, x2) + self.assertEqual(cache.offset, 7) + + cache.restore(snap) + self.assertEqual(cache.offset, offset_at_snap) + self.assertEqual(cache.start_position, start_at_snap) + + def test_concatenate_kv_cache_snap_restore(self): + from mlx_lm.models.cache import ConcatenateKVCache + + cache = ConcatenateKVCache() + x1 = mx.random.uniform(shape=(1, 2, 3, 8)) + cache.update_and_fetch(x1, x1) + snap = cache.snap() + offset_at_snap = cache.offset + keys_at_snap = cache.keys.astype(mx.float32) + values_at_snap = cache.values.astype(mx.float32) + + # Concatenate replaces keys/values by reference; snapshot must + # restore the pre-update references. + x2 = mx.random.uniform(shape=(1, 2, 2, 8)) + cache.update_and_fetch(x2, x2) + self.assertEqual(cache.offset, 5) + + cache.restore(snap) + self.assertEqual(cache.offset, offset_at_snap) + self.assertTrue(mx.array_equal(cache.keys.astype(mx.float32), keys_at_snap)) + self.assertTrue(mx.array_equal(cache.values.astype(mx.float32), values_at_snap)) + + def test_arrays_cache_snap_restore(self): + from mlx_lm.models.cache import ArraysCache + + cache = ArraysCache(size=2, left_padding=[1, 2]) + a0 = mx.random.uniform(shape=(2, 4, 8)) + a1 = mx.random.uniform(shape=(2, 4, 8)) + cache[0] = a0 + cache[1] = a1 + # The model __call__ also reassigns lengths via advance(); model it + # here by setting one explicitly. + cache.prepare(lengths=[3, 4]) + snap = cache.snap() + lengths_at_snap = cache.lengths + left_padding_at_snap = cache.left_padding + + # Mutate as a forward would: reassign list slots by reference and + # advance() consumes one token. + new_a0 = mx.random.uniform(shape=(2, 4, 8)) + cache[0] = new_a0 + cache.advance(1) + self.assertFalse(mx.array_equal(cache.lengths, lengths_at_snap)) + + cache.restore(snap) + # Restored entries are bit-exact references to the originals. + self.assertTrue(mx.array_equal(cache[0], a0)) + self.assertTrue(mx.array_equal(cache[1], a1)) + self.assertTrue(mx.array_equal(cache.lengths, lengths_at_snap)) + self.assertTrue(mx.array_equal(cache.left_padding, left_padding_at_snap)) + + def test_cache_list_snap_restore(self): + from mlx_lm.models.cache import CacheList, KVCache + + c0, c1 = KVCache(), KVCache() + cl = CacheList(c0, c1) + x = mx.random.uniform(shape=(1, 2, 3, 8)) + c0.update_and_fetch(x, x) + c1.update_and_fetch(x, x) + + snap = cl.snap() + # Mutate underlying caches. + y = mx.random.uniform(shape=(1, 2, 2, 8)) + c0.update_and_fetch(y, y) + c1.update_and_fetch(y, y) + self.assertEqual(c0.offset, 5) + self.assertEqual(c1.offset, 5) + + cl.restore(snap) + self.assertEqual(c0.offset, 3) + self.assertEqual(c1.offset, 3) + + def test_snapshot_restore_helpers(self): + # Module-level helpers iterate over the cache list and dispatch to + # each cache's snap/restore. + from mlx_lm.models.cache import ( + ArraysCache, + KVCache, + restore_prompt_cache, + snapshot_prompt_cache, + ) + + kv = KVCache() + x = mx.random.uniform(shape=(1, 2, 3, 8)) + kv.update_and_fetch(x, x) + + ac = ArraysCache(size=1) + ac_a0 = mx.random.uniform(shape=(2, 4, 8)) + ac[0] = ac_a0 + + cache = [kv, ac] + snap = snapshot_prompt_cache(cache) + self.assertEqual(len(snap), 2) + + # Mutate both caches. + y = mx.random.uniform(shape=(1, 2, 2, 8)) + kv.update_and_fetch(y, y) + new_ac0 = mx.random.uniform(shape=(2, 4, 8)) + ac[0] = new_ac0 + self.assertEqual(kv.offset, 5) + self.assertFalse(mx.array_equal(ac[0], ac_a0)) + + restore_prompt_cache(cache, snap) + self.assertEqual(kv.offset, 3) + # ArraysCache restored by-reference to the original entry. + self.assertTrue(mx.array_equal(ac[0], ac_a0)) + + if __name__ == "__main__": unittest.main()